from typing import Dict, Optional, Tuple

import copy
import numpy as np
from PIL import Image
from scipy.ndimage import label

import torch
import torchvision.transforms.functional as TF

import fourm.utils.data_constants as data_constants
from fourm.utils import generate_uint15_hash
from fourm.data.modality_info import MODALITY_INFO, MODALITY_TRANSFORMS
from fourm.data.modality_transforms import  ImageTransform, get_pil_resample_mode

from label_mappings import REPLICA_MAPPING, SCANNET_PP_MAPPING


class SemsegTransform(ImageTransform):
    """Transform for semantic segmentation label maps.

    The map is loaded as a PIL image, converted to palette mode ``'P'`` (one
    uint8 label id per pixel) in :meth:`preprocess`, spatially augmented with
    nearest-neighbour resampling only, and finally turned into a ``torch.long``
    tensor of label ids in :meth:`postprocess`.

    Args:
        scale_factor: Multiplier for height and width applied in
            :meth:`semseg_to_tensor` (1.0 disables rescaling).
        shift_idx_by_one: If True, add 1 to every label id (uint8 arithmetic,
            so 255 wraps around to 0).
        id_mapping: Optional ``{old_id: new_id}`` dict used to remap label ids;
            ids absent from the dict are left unchanged.
        select_channel: If set, keep only this band of the loaded image.
        seg_reduce_zero_label: If True, label 0 (and any pixel already equal to
            ``seg_ignore_index``) becomes ``seg_ignore_index`` and every other
            label is decremented by one.
        seg_ignore_index: Value written to ignored pixels. It is written into a
            uint8 label array, so it should lie in [0, 255].
        filter_instance_thresh: If set, a class whose every connected component
            covers less than this percentage of the image is erased (set to
            ``seg_ignore_index`` when ``seg_reduce_zero_label`` else to 0).
        first_scale_to: If set, resize the map to a
            ``first_scale_to x first_scale_to`` square before cropping.
    """

    def __init__(self, scale_factor=1.0, shift_idx_by_one=False, id_mapping: Optional[Dict] = None, select_channel=None, 
                 seg_reduce_zero_label=False, seg_ignore_index=255, filter_instance_thresh=None, first_scale_to=None):
        self.scale_factor = scale_factor
        self.shift_idx_by_one = shift_idx_by_one
        self.id_mapping = id_mapping
        self.select_channel = select_channel
        self.seg_reduce_zero_label = seg_reduce_zero_label
        self.seg_ignore_index = seg_ignore_index
        self.filter_instance_thresh = filter_instance_thresh
        self.first_scale_to = first_scale_to

    def map_semseg_values(self, sample):
        """Remap label ids through ``self.id_mapping``.

        Ids that are not keys of the mapping are kept as-is. The dict is
        consulted once per distinct id rather than once per pixel (as the
        previous ``np.vectorize`` implementation did).

        Args:
            sample: Label map (PIL image or array-like of integer ids).

        Returns:
            The remapped label map as a mode ``'P'`` PIL image. Mapped values
            are cast to uint8, so they should lie in [0, 255].
        """
        sample = np.asarray(sample)
        # Look up each distinct id once, then scatter the results back to the
        # pixels via the inverse index. reshape() makes this independent of the
        # NumPy version's shape convention for `inverse`.
        unique_ids, inverse = np.unique(sample, return_inverse=True)
        mapped_ids = np.array(
            [self.id_mapping.get(old_id, old_id) for old_id in unique_ids.tolist()],
            dtype=np.int64,
        )
        sample = mapped_ids[inverse].reshape(sample.shape).astype(np.uint8)
        sample = Image.fromarray(sample, mode='P')
        return sample

    def semseg_to_tensor(self, img):
        """Optionally rescale ``img`` by ``self.scale_factor`` and return a long tensor.

        Resizing always uses nearest-neighbour so label ids are never
        interpolated. (Pillow already forces NEAREST for mode 'P', but being
        explicit keeps other modes safe.)

        Returns:
            A ``torch.long`` tensor with the leading size-1 channel dimension
            from ``pil_to_tensor`` squeezed away.
        """
        if self.scale_factor != 1.0:
            target_height, target_width = int(img.height * self.scale_factor), int(img.width * self.scale_factor)
            img = img.resize((target_width, target_height), resample=get_pil_resample_mode('nearest'))
        img = TF.pil_to_tensor(img).to(torch.long).squeeze(0)
        return img

    def remove_tiny_instances(self, img):
        """Erase classes whose connected components are all smaller than the threshold.

        For every class id present in ``img``, the class mask is split into
        connected components with ``scipy.ndimage.label``. If every component
        covers less than ``self.filter_instance_thresh`` percent of the image,
        all pixels of that class are set to ``seg_ignore_index`` (when
        ``seg_reduce_zero_label``) or to 0.

        Returns:
            The filtered label map as a mode ``'P'`` PIL image.
        """
        img = np.array(img)
        fill_value = self.seg_ignore_index if self.seg_reduce_zero_label else 0
        for class_id in np.unique(img):
            class_mask = img == class_id
            labeled_mask, num_components = label(class_mask)
            # Pixel count per component in a single pass; index 0 is the
            # background of the class mask, so it is dropped.
            component_sizes = np.bincount(labeled_mask.ravel(), minlength=num_components + 1)[1:]
            coverages = component_sizes / labeled_mask.size * 100
            if np.all(coverages < self.filter_instance_thresh):
                img[class_mask] = fill_value
        img = Image.fromarray(img, mode='P')
        return img

    def load(self, path):
        """Load the map with the inherited ``pil_loader``, optionally keeping one band."""
        sample = self.pil_loader(path)
        if self.select_channel is not None:
            sample = sample.split()[self.select_channel]
        return sample

    def preprocess(self, sample):
        """Convert to mode ``'P'`` and normalise label ids.

        Steps, in order: optional ``id_mapping`` remap, optional +1 shift, and
        optional reduce-zero-label (see the class docstring).
        """
        # NOTE(review): for multi-band inputs convert('P') quantizes colours
        # instead of preserving ids; this assumes a single-band map (see
        # select_channel) — confirm against the datasets used.
        sample = sample.convert('P')

        if self.id_mapping is not None:
            sample = self.map_semseg_values(sample)

        if self.shift_idx_by_one:
            sample = np.asarray(sample)
            sample = sample + 1
            sample = Image.fromarray(sample, mode='P')

        if self.seg_reduce_zero_label:
            sample = np.array(sample)
            sample[sample == 0] = self.seg_ignore_index
            sample = sample - 1
            # Pixels that were 0 or already ignore have just become ignore-1;
            # restore them to the ignore index.
            sample[sample == self.seg_ignore_index - 1] = self.seg_ignore_index
            sample = Image.fromarray(sample, mode='P')

        return sample

    def image_augment(self, img, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
                      rand_aug_idx: Optional[int], resample_mode: str = None):
        """Optionally pre-resize, then crop/resize and flip the map.

        ``resample_mode`` is intentionally ignored: label maps are always
        resampled with nearest-neighbour. ``orig_size`` and ``rand_aug_idx``
        are accepted for interface compatibility and are unused here.
        """
        if self.first_scale_to is not None:
            img = img.resize(
                (self.first_scale_to, self.first_scale_to), 
                resample=get_pil_resample_mode('nearest')
            )

        img = self.image_crop_and_resize(img, crop_coords, target_size, resample_mode='nearest')
        img = self.image_hflip(img, flip)
        if self.filter_instance_thresh is not None:
            img = self.remove_tiny_instances(img)
        return img

    def postprocess(self, sample):
        """Convert the augmented label map into a ``torch.long`` tensor."""
        img = self.semseg_to_tensor(sample)
        return img
    

# Base fourm modality registry extended with the transfer-dataset semantic
# segmentation modalities. A deep copy is updated so MODALITY_INFO itself is
# left untouched. All three entries are image modalities read from the
# 'semseg' path and differ only in their number of labels.
TRANSFER_MODALITY_INFO = copy.deepcopy(MODALITY_INFO)
TRANSFER_MODALITY_INFO.update({
    modality_name: {
        'type': 'img',
        'num_labels': num_labels,
        'id': generate_uint15_hash(modality_name),
        'path': 'semseg',
    }
    for modality_name, num_labels in (
        ('semseg_procthor', 40),
        ('semseg_replica', 45),
        ('semseg_scannet_pp', 62),
    )
})

# Transform registry for the modalities above, built on a deep copy of the base
# registry. Every transfer semseg modality uses reduce-zero-label with the
# project ignore index. Replica and ScanNet++ additionally remap raw ids through
# their label-mapping tables; ProcTHOR uses no remapping (None, the default).
TRANSFER_MODALITY_TRANSFORMS = copy.deepcopy(MODALITY_TRANSFORMS)
TRANSFER_MODALITY_TRANSFORMS.update({
    modality_name: SemsegTransform(
        seg_reduce_zero_label=True,
        seg_ignore_index=data_constants.SEG_IGNORE_INDEX,
        id_mapping=label_mapping,
    )
    for modality_name, label_mapping in (
        ('semseg_procthor', None),
        ('semseg_replica', REPLICA_MAPPING),
        ('semseg_scannet_pp', SCANNET_PP_MAPPING),
    )
})
