import os
from dataclasses import dataclass, asdict
from functools import partial
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torchvision.transforms import transforms, InterpolationMode

from data.BlenderDataset.Generate_LMDB_dataset import return_factory

# Dataset folder name; used as the default value of CfgDataset.dataset_root_folder.
root = 'CelebAMask-HQ'


@dataclass
class CfgDataset:
    """Configuration of the image / sketch / mask dataset used by this script.

    An instance is converted to a plain dict with ``asdict`` and passed on to
    the data-loading code as its parameter dictionary.
    """

    image_height: int = 512
    image_width: int = 512

    initial_img_size: int = 1024
    initial_sketch_size: int = 1024
    initial_mask_size: int = 512
    # Variable-length tuple: the default carries three entries, not one.
    initials_dims: Tuple[int, ...] = (3, 1, 19)
    segmentation_fusion: bool = True

    # Index of the item fetched from the dataset [0: image, 1: sketch, 2: mask].
    index_data: int = 0
    list_dim: list[int] = return_factory([3, 1, 19])
    # None until it is set after instantiation from list_dim[index_data].
    image_channels: Optional[int] = None

    dataset_root_folder: str = root
    shuffle: bool = True
    return_params: bool = False

    # Use -1 to skip the size reduction of a split.
    reduced_size_train: int = 100
    reduced_size_valid: int = 100
    reduced_size_test: int = 100

    # NOTE(review): presumably expected to sum to 1.0 — confirm in the loader.
    train_proportion: float = 0.60
    valid_proportion: float = 0.20
    test_proportion: float = 0.20


# Build the default configuration, then fill in the channel count that
# corresponds to the selected data kind (index_data indexes into list_dim).
cfg_dataset = CfgDataset()
cfg_dataset.image_channels = cfg_dataset.list_dim[cfg_dataset.index_data]

# Parameter dictionary for the loader: every config field plus the batch sizes.
params = {
    **asdict(cfg_dataset),
    'batch_size_train': 10,
    'batch_size_valid': 10,
    'batch_size_test': 10,
}


from datasets.dataloading import get_celebAHQ3
# Keep only the first ten training samples, materialised as a plain list.
_train_split = get_celebAHQ3(params=params).train_ds
dataset = [_train_split[sample_idx] for sample_idx in range(10)]
del _train_split  # the full split is not needed once the samples are copied out

import numpy as np


from PIL import Image


# Side length used by the f2 pipeline, which reshapes masks without resizing them.
initial_size = 512
# Side length the f1 pipeline resizes masks to; tocolor reshapes to it as well.
target_size = 128
# Number of mask channels that tocolor reshapes its input to.
nb_classes = 19
def tocolor(seg):
    """Convert a segmentation tensor to a 3-channel image through ``seg2rgb``.

    ``seg`` is reshaped to ``(1, nb_classes, target_size, target_size)``,
    passed to ``seg2rgb``, and the result is reshaped to
    ``(3, target_size, target_size)``.

    NOTE(review): ``seg2rgb`` is not imported in the visible part of this
    file — confirm where it is expected to come from.
    """
    batched = seg.reshape(1, nb_classes, target_size, target_size)
    return seg2rgb(batched).reshape(3, target_size, target_size)

# Directory prefix of every image file written by the dump loop below.
root_save = "datasets/CelebA/dump"

def _build_mask_pipeline(size, resize=False):
    """Compose the mask transform chain for a ``size`` x ``size`` mask.

    Args:
        size: side length handed to ``Reshape`` (and to ``Resize`` when used).
        resize: when true, a nearest-neighbour ``Resize`` to ``size`` is
            inserted before the reshape.

    Returns:
        A ``transforms.Compose`` applying, in order: ``toTensorNoNorm``,
        ``toType`` (int64), optional ``Resize``, ``Reshape``, one-hot encoding
        with ``nb_classes`` classes, ``Permute([2, 0, 1])`` and ``CollapseMask``.

    NOTE(review): ``toTensorNoNorm``, ``toType``, ``Reshape``, ``Permute`` and
    ``CollapseMask`` are not imported in the visible part of this file —
    confirm where they are expected to come from.
    """
    steps = [
        toTensorNoNorm(),
        toType(target_type=torch.int64),
    ]
    if resize:
        # Nearest-neighbour interpolation, as in the original f1 definition.
        steps.append(transforms.Resize([size, size], InterpolationMode.NEAREST))
    steps += [
        Reshape([size, size]),
        partial(F.one_hot, num_classes=nb_classes),
        Permute([2, 0, 1]),  # move the one-hot axis added last by F.one_hot to the front
        CollapseMask(),
    ]
    return transforms.Compose(steps)


# f1 resizes the mask to target_size first; f2 keeps it at initial_size.
f1 = _build_mask_pipeline(target_size, resize=True)
f2 = _build_mask_pipeline(initial_size)

def _save_mask_channels(collapsed, sample_idx, suffix):
    """Write every channel of ``collapsed`` to ``root_save`` as a JPEG file.

    Args:
        collapsed: output of the f1 / f2 pipeline; iterated along its first
            axis, one file per entry.
        sample_idx: index of the sample, used as the file-name prefix.
        suffix: pipeline tag appended to the file name (``"m1"`` or ``"m2"``).

    Files are named ``{sample_idx}_{class:02d}_{collapsed name}_{suffix}.jpeg``.
    """
    for i_class, channel in enumerate(collapsed):
        attr_name = ' '.join(CollapseMask.getCollapsedName(i_class))
        # Scale by 255 before the uint8 cast, as the original loops did.
        pixels = (channel * 255).numpy().astype(np.uint8)
        Image.fromarray(pixels).save(
            f"{root_save}/{sample_idx}_{i_class:02d}_{attr_name}_{suffix}.jpeg"
        )


# Image.save does not create missing directories, so make sure the target exists.
os.makedirs(root_save, exist_ok=True)

for i, (face, _sketch, mask) in enumerate(dataset):
    m1 = f1(mask)
    m2 = f2(mask)

    # permute(1, 2, 0) moves the first axis last before handing the array to PIL.
    face_pixels = (face * 255).permute(1, 2, 0).numpy().astype(np.uint8)
    Image.fromarray(face_pixels).save(f"{root_save}/{i}.jpeg")

    _save_mask_channels(m1, i, "m1")
    _save_mask_channels(m2, i, "m2")

print('end')
