from monai import transforms as mtf
import numpy as np
import nibabel as nib 
import torch

from utils.nib_utils import get_spacing, resample_scan

# Voxel spacing every scan is resampled to before entering the model.
# NOTE(review): units and axis order follow utils.nib_utils.get_spacing
# (presumably millimetres in affine axis order) — confirm there.
TARGET_SPACING = np.array([1.0, 1.0, 2.0])
# Lower bound of the HU normalization window; maps to 0.0 after normalization.
# It is one below -1024 so that -1025 can act as a sentinel value for background.
MIN_HU = -1024 - 1  # -1025 for sentinel value of background
# Upper bound of the HU normalization window; maps to 1.0 after normalization.
# NOTE: MAX_HU - MIN_HU = 3070 - (-1025) = 4095, NOT 4096, so the range is not
# a power of two. Setting MAX_HU = 3071 would make it 4096, but that would also
# change the normalization of every preprocessed scan (and of any model trained
# on them), so the existing value is kept deliberately.
MAX_HU = 3070

class NLSTMixin:
    """
    Mixin class providing NLST-specific preprocessing and postprocessing methods.

    Preprocessing resamples a scan to ``TARGET_SPACING`` and maps intensities
    linearly from [MIN_HU, MAX_HU] to [0, 1]. The reverse methods undo both
    steps so results can be written back on the original voxel grid.
    """

    def preprocess_scan(self, scan: nib.Nifti1Image) -> torch.Tensor:
        """Resample a scan to ``TARGET_SPACING`` and normalize it to [0, 1].

        Args:
            scan: Input NIfTI image; its voxel spacing is read from ``scan.affine``.

        Returns:
            A float32 tensor of the resampled volume in which MIN_HU maps to 0
            and MAX_HU maps to 1. Values outside that window are clipped.
        """
        spacing = get_spacing(scan.affine)
        data = scan.get_fdata()
        # Linear interpolation (order=1) for continuous intensity data.
        resampled_data = resample_scan(data, spacing, TARGET_SPACING, order=1)
        data = torch.from_numpy(resampled_data).float()
        data = (data - MIN_HU) / (MAX_HU - MIN_HU)
        return torch.clip(data, 0, 1)

    def reverse_transform(self, data: torch.Tensor, original_affine: np.ndarray, as_nifty: bool = False) -> torch.Tensor | nib.Nifti1Image:
        """Undo ``preprocess_scan``: de-normalize to HU and resample back.

        Args:
            data: Normalized volume on the ``TARGET_SPACING`` grid; values are
                clipped to [0, 1] before de-normalization.
            original_affine: Affine of the original scan. Its spacing is the
                resampling target.
            as_nifty: If True, return a ``nib.Nifti1Image`` built with
                ``original_affine``. Otherwise return a tensor on ``data``'s device.

        Returns:
            An int16 HU volume on the original voxel grid, either as a NIfTI
            image or as a tensor (see ``as_nifty``).
        """
        original_spacing = get_spacing(original_affine)
        orig_device = data.device
        data = torch.clip(data, 0, 1)
        data = data * (MAX_HU - MIN_HU) + MIN_HU
        # detach() so tensors that are part of an autograd graph can still be
        # converted to numpy.
        resampled = resample_scan(data.detach().cpu().numpy(), TARGET_SPACING, original_spacing, order=1)
        # Interpolation produces fractional HU values. Round to nearest before
        # the integer cast: astype alone truncates toward zero, which would bias
        # every voxel toward 0 HU by up to 1.
        hu = np.round(resampled).astype(np.int16)
        if as_nifty:
            return nib.Nifti1Image(hu, original_affine)
        return torch.from_numpy(hu).to(orig_device)

    def reverse_mask_transform(self, masks: torch.Tensor, original_affine: np.ndarray, as_nifty: bool = False) -> torch.Tensor | nib.Nifti1Image:
        """Resample label masks from ``TARGET_SPACING`` back to the original grid.

        Args:
            masks: Label volume on the ``TARGET_SPACING`` grid.
            original_affine: Affine of the original scan. Its spacing is the
                resampling target.
            as_nifty: If True, return a ``nib.Nifti1Image`` built with
                ``original_affine``. Otherwise return a tensor on ``masks``'s device.

        Returns:
            An int16 label volume on the original voxel grid.
        """
        original_spacing = get_spacing(original_affine)
        orig_device = masks.device
        # Nearest-neighbour (order=0) so no new, interpolated label values appear.
        resampled = resample_scan(masks.detach().cpu().numpy(), TARGET_SPACING, original_spacing, order=0)
        labels = np.round(resampled).astype(np.int16)
        if as_nifty:
            return nib.Nifti1Image(labels, original_affine)
        return torch.from_numpy(labels).to(orig_device)


