import json
import logging
from pathlib import Path

import nibabel as nib
import torch
from utils.segmentations import load_nib_segmentation

from .base import DatasetWithRegionsInAnOrgan
from utils.nib_utils import get_spacing
from .nlst_mixin import NLSTMixin

# Module-level logger named after this module.
log = logging.getLogger(__name__)



class ildctWithAnnotationsDataset(NLSTMixin, DatasetWithRegionsInAnOrgan):
    """CT scans paired with an annotation (nodule) mask and a lung mask.

    Records come from a JSON index (``path_data``). Each record is read through
    its ``'image'`` key (scan path) and its ``'segmentation'`` key (annotation
    mask path). The lung mask for a scan is taken from ``path_tcs_dir`` under
    the same file name as the scan. Only the window of records
    ``[n_skip, n_skip + len(self))`` is exposed.
    """

    def __init__(
        self,
        name: str,
        split: str,
        path_data: str,
        path_tcs_dir: str,
        path_labels: str,
        path_predictions: str,
        n_samples: int,
        filter_id: int,
        n_skip: int = 0,
    ):
        """Load the JSON index and set up the visible window of records.

        Args:
            name: Dataset name, stored on the instance.
            split: Must be ``'train'`` or ``'validation'``.
            path_data: JSON file listing the records (see class docstring).
            path_tcs_dir: Directory holding the lung segmentations, one file
                per scan, named like the scan file.
            path_labels: Accepted for interface compatibility; unused.
            path_predictions: Accepted for interface compatibility; unused
                until predictions are supported.
            n_samples: Upper bound on the number of exposed samples.
            filter_id: Reported as ``pred_label`` for every sample.
            n_skip: Number of leading records of the index to skip.
        """
        super().__init__()
        assert split in ['train', 'validation']
        self.name = name
        self.split = split
        self.path_data = path_data
        self.path_tcs_dir = Path(path_tcs_dir)
        self.n_samples = n_samples
        self.filter_id = filter_id
        self.n_skip = n_skip

        # Use a context manager so the index file handle is closed promptly.
        with open(self.path_data, encoding="utf-8") as f:
            self.data = json.load(f)

        # Filtering is not supported yet.
        # TODO: Compute predictions. An earlier draft loaded them with
        # pd.read_csv(path_predictions, index_col="idx") when a path was given.
        log.warning("Filtering not supported yet")

        # Clamp at 0: with n_skip > len(self.data) (or n_samples < 0) the raw
        # value is negative, and len() raises ValueError on negative lengths.
        self.length = max(0, min(len(self.data) - n_skip, self.n_samples))

        # Bound method instead of a lambda: lambdas cannot be pickled, which
        # breaks DataLoader worker processes started with "spawn".
        self.map_index = self._skip_index

    def _skip_index(self, index):
        """Translate a dataset position into an index into ``self.data``."""
        return index + self.n_skip

    def __len__(self):
        """Return the number of exposed samples (never negative)."""
        return self.length

    def get_path(self, index):
        """Return the ``'image'`` path stored in record ``index``.

        Unlike ``__getitem__``, no ``n_skip`` offset is applied: ``index``
        addresses ``self.data`` directly. This is the same index space as the
        ``"image_idx"`` metadata entry returned by ``__getitem__``.
        """
        return self.data[index]['image']

    def __getitem__(self, index):
        """Load the scan and both segmentations for dataset position ``index``.

        Negative positions count from the end of the exposed window.

        Returns:
            A tuple ``(index, volume, pred_label, true_label, lungs_seg,
            nodule_seg, metadata)`` with these entries:

            - ``index``: the position in ``self.data``, after the ``n_skip``
              offset.
            - ``volume``: the output of ``preprocess_scan`` with a leading
              singleton dimension added.
            - ``pred_label``: ``filter_id``.
            - ``true_label``: always 1.
            - ``lungs_seg`` and ``nodule_seg``: both segmentations, loaded with
              ``volume.shape[-3:]`` as the target shape.
            - ``metadata``: path, series id, and the scan's original affine and
              spacing.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        if index < 0:
            index += self.length
        if not 0 <= index < self.length:
            # Without this bound, positions past len(self) would silently read
            # records outside the n_samples window, and plain iteration over
            # the dataset would not stop at len(self).
            raise IndexError(
                f"index out of range for dataset of length {self.length}"
            )
        index = self.map_index(index)
        record = self.data[index]

        scan_path = Path(record['image'])
        series_id = scan_path.name.removesuffix('.nii.gz')
        scan = nib.load(scan_path)
        volume = self.preprocess_scan(scan).unsqueeze(0)
        spatial_shape = volume.shape[-3:]

        # Nodule segmentation: path stored in the record itself.
        nodule_seg = load_nib_segmentation(
            Path(record['segmentation']), spatial_shape
        )

        # Lung segmentation: same file name as the scan, inside path_tcs_dir.
        lungs_seg = load_nib_segmentation(
            self.path_tcs_dir / scan_path.name, spatial_shape
        )

        # TODO: Possibly add predictions
        pred_label = self.filter_id
        true_label = 1  # Fixed to 1, although some scans even contain multiple nodules.

        metadata = {
            "image_idx": index,
            "image_path": str(scan_path),
            "series_id": series_id,
            "original_affine": scan.affine,
            "original_spacing": get_spacing(scan.affine),
        }

        return index, volume, pred_label, true_label, lungs_seg, nodule_seg, metadata

    def get_guidance_classes(self, config, fabric, batch_labels, batch_pred_labels):
        """Return ``config.exp.guide_id`` repeated once per entry of ``batch_labels``.

        ``fabric`` and ``batch_pred_labels`` are not used here; presumably they
        are accepted to match the signature callers use (TODO confirm).
        """
        guide_id = config.exp.guide_id
        return torch.tensor([guide_id] * len(batch_labels))