
from torch.utils.data import Dataset
from abc import ABC

import cc3d
import numpy as np
import torch
from torch import Tensor
from utils.segmentations import find_region_center, enlarge_segmentation_torch

# Threshold for treating a voxel sum as non-zero in the mask checks below.
EPS: float = 1.0e-6

import logging
log = logging.getLogger(name=__name__)


class CTRegionExtractor:
    def __init__(self, max_nodules: int, max_backgrounds: int, enlarge_mask: int, dust_threshold: int, max_nodule_volume: int = 64 ** 3):
        self.max_nodules = max_nodules
        self.max_backgrounds = max_backgrounds
        self.enlarge_mask = enlarge_mask
        self.dust_threshold = dust_threshold
        self.max_nodule_volume = max_nodule_volume


    def get_region_segments(self, region_seg_mask: Tensor) -> tuple[Tensor, int]|tuple[None, None]:
        """
        Convert the binary region segmentation mask into a multi-class mask.
        The input mask is expected to be of shape [W, H, D] with values 0 or 1.
        """
        if (region_seg_mask > 1).any():
            log.info("Region segmentation mask is already a multi-class mask")
            return region_seg_mask, int(torch.max(region_seg_mask).item())

        mask = enlarge_segmentation_torch(region_seg_mask, self.enlarge_mask)
        mask_np = mask.detach().cpu().numpy()
        mask_np = cc3d.dust(mask_np, self.dust_threshold)
        mc_mask_np, N = cc3d.connected_components(mask_np, return_N=True)
        mc_mask_np = mc_mask_np.astype(np.uint8)
        if N == 0:
            log.warning("No regions found in the segmentation")
            return None, None
        log.info(f"Found {N} regions in the segmentation")
        if N > self.max_nodules:
            log.warning(f"Truncating regions to {self.max_nodules}")
            N = self.max_nodules

        volume_per_idx = [(idx, np.sum(mc_mask_np == idx)) for idx in range(1, N+1) ]
        volume_per_idx.sort(reverse = True, key = lambda x: x[1])
        valid_nodule_idxs = [
            idx for (idx, vol) in volume_per_idx if (vol > EPS) and (vol < self.max_nodule_volume)
        ]
        if len(valid_nodule_idxs) == 0:
            log.warning("No nodule fitting criterion were found. Ignoring max_nodule_volume condition.")
            valid_nodule_idxs = [ idx for (idx, vol) in volume_per_idx if (vol > EPS) ]
        index_map = { idx: i+1 for i, idx in enumerate(valid_nodule_idxs) }
        mc_mask_np2 = np.zeros_like(mc_mask_np, dtype=np.uint8)
        for i in range(1, N+1):
            if i in index_map:
                mc_mask_np2[mc_mask_np == i] = index_map[i]
        return torch.from_numpy(mc_mask_np2), len(valid_nodule_idxs)


    def get_lung_region_segments(
            self,
            lung_seg_mask: Tensor,
            nodule_seg_mask: Tensor | None = None,
            region_size: int = 64,
        ):
        # Creates a multiclass mask which effectivly
        # covers the entirety of the lungs except 
        # for the regions that contain nodules.
        segmentation = torch.zeros_like(lung_seg_mask)
        region_centers = []
        half_size = region_size // 2
        cstart, cstop = half_size // 2, half_size + half_size // 2
        inner_slice = tuple(
            [slice(cstart, cstop)] * 3
        ) 

        x_max, y_max, z_max = lung_seg_mask.shape
        i = 1
        for x in range(0, x_max-region_size, half_size):
            for y in range(0, y_max-region_size, half_size):
                for z in range(0, z_max-region_size, half_size):
                    region_slice = (
                        slice(x, x + region_size),
                        slice(y, y + region_size),
                        slice(z, z + region_size),
                    )
                    inner_region_slice = (
                        slice(x + cstart, x + cstop),
                        slice(y + cstart, y + cstop),
                        slice(z + cstart, z + cstop),
                    )
                    if nodule_seg_mask is not None:
                        # Skip regions that contain nodules
                        if (torch.sum(nodule_seg_mask[region_slice]) > EPS):
                            continue


                    region = lung_seg_mask[region_slice]
                    inner_region = region[inner_slice]
                    if torch.mean(inner_region) > 0.95:
                        segmentation[inner_region_slice] = i

                        center = torch.Tensor([
                            x + half_size,
                            y + half_size,
                            z + half_size
                        ]).float()
                        region_centers.append(center)
                        i += 1

        if torch.any(segmentation > 0):
            return segmentation, torch.stack(region_centers)
        else:
            return None, None


    def get_optimal_bg_subset(
                self,
                region_centers: Tensor, 
                lung_mask: Tensor, 
                nodule_centers: Tensor | None = None,
            ) -> list[int]:
            # Finds a subset of the lung regions
            # region_centers: [N, 3]
            # lung_mask: [W, H, D]
            # nodule_centers: [M, 3]
            if len(region_centers) <= self.max_backgrounds:
                return None

            log.info("Finding background subset of size %d", self.max_backgrounds)

            lung_coordinates = lung_mask.nonzero().float()
            if nodule_centers is not None:
                log.info("Using nodule centers")
                n, m = len(nodule_centers), len(region_centers)
                selected_indices = set(range(m, n+m))
                possible_indices = set(range(m))

                all_centers = torch.cat(
                    [region_centers, nodule_centers.to(region_centers.device)], dim=0
                ) 
            else:
                selected_indices = set()
                possible_indices = set(range(len(region_centers)))
                all_centers = region_centers

            all_centers = all_centers.float().to(lung_coordinates.device)
            distance_matrix = torch.cdist(
                lung_coordinates.unsqueeze(0), all_centers.unsqueeze(0), p=1
            ).squeeze(0)
            distance_matrix = (distance_matrix - 16).clip(min=0)
            log.info("Distance matrix shape: %s", distance_matrix.shape)

            # Greedy search
            max_indices = len(selected_indices) + self.max_backgrounds
            while len(selected_indices) < max_indices:
                # Find the center with the minimum distance to the selected centers
                min_i = 0
                min_value = float('inf')
                for i in possible_indices:
                    indicies = list(selected_indices) + [i]
                    max_distance, _ = torch.max(distance_matrix[:, indicies], dim=1)
                    cost = torch.mean(max_distance).item()
                    if cost < min_value:
                        min_value = cost
                        min_i = i

                selected_indices.add(min_i)
                possible_indices.remove(min_i)

            if nodule_centers is not None:
                selected_indices = [i for i in selected_indices if i < len(region_centers)]

            return dict(zip(
                [i+1 for i in selected_indices],
                list(range(1, len(selected_indices) + 1))
            ))