from typing import Tuple, Optional, Callable

from jaxtyping import jaxtyped, Float, Int
import torch
import torch.utils.data as data

import os.path as osp
import os
from PIL import Image
import numpy as np
from beartype import beartype as typechecker

from conf.dataset import CelebAParams
from data.CelebADataset.celebahq_transforms import get_image_transform, get_sketch_transform, get_mask_transform, \
    get_join_transform
from data.UtilsDataset import CustomDataModule
from utils.utils import display_tensor, display_mask


src_face = 'CelebA-HQ-img/'
src_sketch = 'CelebA-HQ-sketch/'
src_sketch2 = 'CelebA-HQ-sketch2/'
src_mask = 'CelebAMask-HQ-mask-anno/'
src_mask_one = 'CelebAMask-HQ-mask-anno_One/'


def fromFilenameGetNumber(filename: str) -> int:
    """Return the integer encoded before the first '.' of ``filename``.

    Used as a sort key so that e.g. ``'10.jpg'`` sorts after ``'9.jpg'``.

    Raises:
        ValueError: if the text before the first '.' is not an integer literal.
    """
    stem, _dot, _suffix = filename.partition('.')
    return int(stem)


class CelebAHQDataset(data.Dataset):
    """CelebAMask-HQ dataset yielding aligned (RGB image, sketch, segmentation mask) samples.

    The image, sketch and mask folders under ``params.root`` are each listed and
    sorted by the integer in their file names. Entries are paired by position,
    so the three folders are expected to hold the same file numbers.

    Behaviour controlled by ``params``:
      * ``sketch_version`` (1 or 2) selects the ``src_sketch`` or ``src_sketch2`` folder;
      * ``return_background`` keeps (True) or drops (False) mask channel 0;
      * ``only_image`` returns only the domain selected by ``targeted_domain``;
      * ``return_indice`` additionally returns the sample index.
    """

    def __init__(
        self,
        params: CelebAParams,
    ):
        """Index the dataset folders and build the per-domain and joint transforms.

        Args:
            params: dataset configuration (root folder, sketch version, output flags).

        Raises:
            ValueError: if ``params.sketch_version`` is neither 1 nor 2.
        """
        super().__init__()
        # Not read anywhere in this class; presumably an "ignore" label value
        # consumed by external code -- TODO confirm before removing.
        self.ignore_lb = 255
        self.params = params

        self.imgs = sorted(os.listdir(os.path.join(params.root, src_face)), key=fromFilenameGetNumber)
        if params.sketch_version == 1:
            self.sketch_path = src_sketch
        elif params.sketch_version == 2:
            self.sketch_path = src_sketch2
        else:
            raise ValueError(f'Unknown sketch version {params.sketch_version=}')
        self.sketchs = sorted(os.listdir(os.path.join(params.root, self.sketch_path)), key=fromFilenameGetNumber)
        self.masks = sorted(os.listdir(os.path.join(params.root, src_mask_one)), key=fromFilenameGetNumber)

        self.image_transform = get_image_transform(params)
        self.sketch_transform = get_sketch_transform(params)
        self.mask_transform = get_mask_transform(params)
        self.whole_transform = get_join_transform(params)

    @jaxtyped
    @typechecker
    def get_3_dom(self, idx: int) -> Tuple[
        Float[torch.Tensor, '3 h w'],  # RGB
        Float[torch.Tensor, '1 h w'],  # sketch
        Float[torch.Tensor, 'c h w'],  # segmentation mask
    ]:
        """Load the image, sketch and mask of sample ``idx``.

        Processing steps:
          1. Each domain goes through its own transform, when one is set.
          2. The three results go through the joint transform, when one is set.
          3. Mask channel 0 (background) is dropped unless
             ``params.return_background`` is True.
          4. Image and sketch are rescaled from [0, 1] to [-1, 1].

        Returns:
            ``(img, sketch, mask)`` as float tensors. The shapes are runtime-checked
            against the annotations above, with ``h`` and ``w`` shared by all three.
        """
        img_path = self.imgs[idx]
        img = Image.open(osp.join(self.params.root, src_face, img_path))
        img = self.image_transform(img) if self.image_transform else img

        sketch_path = self.sketchs[idx]
        sketch = Image.open(osp.join(self.params.root, self.sketch_path, sketch_path))
        sketch = self.sketch_transform(sketch) if self.sketch_transform else sketch

        mask_path = self.masks[idx]
        mask = Image.open(osp.join(self.params.root, src_mask_one, mask_path)).convert('P')
        # 'P' images are single-band, so this yields a (1, H, W) int64 label map.
        # mask_transform presumably expands it into per-class channels -- TODO confirm.
        mask = np.array(mask).astype(np.int64)[np.newaxis, :]
        mask = self.mask_transform(mask) if self.mask_transform else mask

        # Joint transform (e.g. shared geometric augmentation) keeps the three domains aligned.
        img, sketch, mask = self.whole_transform([img, sketch, mask]) if self.whole_transform else (img, sketch, mask)

        # remove the background class
        if not self.params.return_background:
            mask = mask[1:]

        # img and sketch are expected in [0, 1] here; put them in [-1, 1]
        img = img * 2 - 1
        sketch = sketch * 2 - 1

        return img.float(), sketch.float(), mask.float()

    @jaxtyped
    @typechecker
    def get_1_dom(self, idx: int) -> Float[torch.Tensor, 'c h w']:
        """Return the single domain of sample ``idx`` selected by ``params.targeted_domain``.

        ``targeted_domain`` indexes the ``(img, sketch, mask)`` tuple from ``get_3_dom``.
        """
        all_domains = self.get_3_dom(idx)
        return all_domains[self.params.targeted_domain]

    def __getitem__(self, idx: int):
        """Return sample ``idx``, optionally followed by its index.

        Returns:
            If ``params.only_image`` is set: the tensor of the targeted domain, or
            ``(tensor, idx)`` when ``params.return_indice`` is set.
            Otherwise: ``(img, sketch, mask)``, with ``idx`` appended as a fourth
            element when ``params.return_indice`` is set.
        """
        if self.params.only_image:
            # get_1_dom returns a bare tensor, not a tuple. ``tensor + tuple``
            # raises TypeError, so build the optional (tensor, idx) pair explicitly.
            single = self.get_1_dom(idx)
            return (single, idx) if self.params.return_indice else single

        triplet = self.get_3_dom(idx)
        return triplet + (idx,) if self.params.return_indice else triplet

    def __len__(self) -> int:
        """Number of samples, i.e. the number of files in the image folder."""
        return len(self.imgs)


class CelebADataModule(CustomDataModule):
    """Data module that builds one ``CelebAHQDataset`` and splits it into train/valid/test."""

    def _fetch_base_dataset(self) -> Tuple[data.Dataset, data.Dataset, data.Dataset]:
        """Build the full CelebA-HQ dataset from ``self.p.data_params`` and split it.

        Returns:
            ``(train, valid, test)`` datasets, in the order produced by ``self.split_dataset``.
        """
        full_dataset = CelebAHQDataset(self.p.data_params)
        train_split, valid_split, test_split = self.split_dataset(full_dataset)
        return train_split, valid_split, test_split


if __name__ == '__main__':
    # Smoke test: build the dataset with default parameters, then print its size
    # and the shapes of the first sample. Unpacking into three names assumes the
    # default params return the (img, sketch, mask) triplet.
    dataset = CelebAHQDataset(params=CelebAParams())
    print(f'{len(dataset)=}')

    rgb, sketch, mask = dataset[0]
    for tensor in (rgb, sketch, mask):
        print(tensor.shape)

    print('end')
