from PIL import Image
import numpy as np
import torch
import os

import imageio
imageio.plugins.freeimage.download()
from . import region_cityscapes_or_tensor

class RegionCityscapesOr(region_cityscapes_or_tensor.RegionCityscapesOr):
    """Cityscapes region dataset returning precise labels plus a selected-superpixel mask.

    Extends ``region_cityscapes_or_tensor.RegionCityscapesOr``. Each sample holds:
    - the transformed image;
    - the precise gtFine label map, with the ignore label 255 remapped to class 19;
    - the per-superpixel multi-hot class table for the image;
    - the superpixel map;
    - a boolean mask of pixels whose superpixel id is in the actively selected set.
    """

    def __init__(self, args, root, datalist, split='train', transform=None, return_spx=False,
                 region_dict="dataloader/init_data/cityscapes/train.dict", mask_region=True, dominant_labeling=False, loading='binary', load_smaller_spx=False):
        super().__init__(args, root, datalist, split, transform, return_spx, region_dict, mask_region, dominant_labeling, loading, load_smaller_spx)
        # This subclass only supports region masking on full-size superpixel maps.
        assert self.mask_region
        assert not self.load_smaller_spx

        # remove_dominant is for analysis and visualization: selected superpixels
        # with exactly one class are dropped. They are kept only when saving
        # pseudo labels (methods containing 'eval_save').
        self.remove_dominant = 'eval_save' not in args.method

    def __getitem__(self, index):
        """Return the sample at position ``index`` of ``self.im_idx``.

        NOTE: per the original author's warning, ``index`` is defined
        superpixel-wise by the base class — verify against how ``im_idx`` is built.

        Returns:
            dict with keys:
              'images' - transformed RGB image
              'labels' - transformed precise label map (255 remapped to 19)
              'target' - ``self.multi_hot_cls`` entry for this image
              'spx'    - transformed superpixel map
              'spmask' - bool mask, True where the pixel's superpixel id is selected
              'fnames' - the raw ``self.im_idx[index]`` entry
        """
        img_fname, lbl_fname, spx_fname = self.im_idx[index]

        # Load image and superpixel map.
        image = Image.open(img_fname).convert('RGB')
        superpixel = self.open_spx(spx_fname)

        # Build the gtFine label path from the frame id (stem of lbl_fname) and
        # its city prefix (text before the first '_').
        # NOTE(review): 'train' is hard-coded rather than using self.split.
        img_id = lbl_fname.split('/')[-1].split('.')[0]
        city = img_id.split('_')[0]
        target_path = '{}/gtFine/train/{}/{}_gtFine_labelIds.png'.format(self.root, city, img_id)
        # Context manager releases the file handle; the astype() copy is independent of it.
        with Image.open(target_path) as lbl_img:
            target_precise = torch.from_numpy(self.encode_target(lbl_img).astype('uint8'))
        # Keep the ignore label (255) as its own class, index 19.
        target_precise = torch.masked_fill(target_precise, target_precise == 255, 19)
        target_precise = Image.fromarray(target_precise.numpy())

        # Jointly transform (e.g. resize) the image, label map and superpixel map.
        image, (target_precise, superpixel) = self.transform(image, [target_precise, superpixel])

        # Ids of the actively selected superpixels for this image. An integer dtype
        # is forced: torch.tensor([]) defaults to float32, which cannot be used as
        # an index below and would crash when nothing has been selected.
        preserving_labels = torch.tensor(
            self.suppix[spx_fname] if spx_fname in self.suppix else [],
            dtype=torch.long,
        )

        trg_index = self.id_to_index[img_id]
        target = self.multi_hot_cls[trg_index]  # presumably [nseg x (num_classes + 1)] multi-hot

        if self.remove_dominant:
            # Keep only selected superpixels whose multi-hot row does NOT have
            # exactly one active class.
            n_classes = target[preserving_labels].sum(dim=1)  # [nselected]
            preserving_labels = preserving_labels[n_classes != 1]

        # Pixels that belong to a selected superpixel. The label map itself is
        # deliberately left unmasked; consumers use 'spmask' instead.
        sp_mask = torch.isin(superpixel, preserving_labels)

        return {
            'images': image,
            'labels': target_precise,
            'target': target,
            'spx': superpixel,
            'spmask': sp_mask,
            'fnames': self.im_idx[index],
        }