import json
import nibabel as nib
import pandas as pd
import torch
import json

from pathlib import Path
from .base import DatasetWithRegionsInAnOrgan

from utils.nib_utils import get_spacing
from utils.segmentations import load_nib_segmentation, load_nib_multilabel_segmentation

from utils.nib_utils import get_spacing
from .nlst_mixin import NLSTMixin

import logging

# Module-level logger; LUNA25DatasetWithSegmentations.__init__ uses it to report
# which segmentation type was selected.
log = logging.getLogger(__name__)


class LUNA25DatasetWithSegmentations(NLSTMixin, DatasetWithRegionsInAnOrgan):
    """LUNA25 CT scans paired with nodule and TotalSegmentator lung segmentations.

    The CSV at ``path_data`` is reduced to one row per ``PatientIDWithDate``.
    Only rows with a nodule segmentation file in ``nodule_segmentation_dir`` are
    kept. After skipping ``n_skip`` rows, at most ``n_samples`` items are exposed.
    """

    def __init__(
            self, 
            name: str,
            split: str,
            path_data: str, 
            ct_root_dir: str,
            nodule_segmentation_dir: str,
            ts_segmentation_dir: str,
            segmentation_type: str,
            path_labels: str, 
            path_predictions: str, 
            n_samples: int, 
            filter_id: int,
            n_skip: int = 0):
        """Load the scan table and keep only scans that have a nodule segmentation.

        Args:
            name: Dataset name, stored as ``self.name``.
            split: Must be ``'train'`` or ``'validation'``.
            path_data: CSV read with pandas. The columns used here are
                ``PatientIDWithDate``, ``SeriesInstanceUID``, ``PatientID``,
                ``StudyDate`` and ``path``.
            ct_root_dir: Directory that the CSV ``path`` column is relative to.
            nodule_segmentation_dir: Directory holding the
                ``<SeriesInstanceUID><suffix>`` segmentations and the
                ``<SeriesInstanceUID>_nodule_labels.json`` files.
            ts_segmentation_dir: Directory with one sub-directory per ``PatientID``
                containing TotalSegmentator ``*.nii.gz`` files.
            segmentation_type: ``"nnunetr"`` or ``"simple"``; selects the
                nodule segmentation file suffix.
            path_labels: Accepted for interface compatibility; not used here.
            path_predictions: Optional CSV indexed by ``idx``. It is loaded into
                ``self.predictions``, which ``__getitem__`` does not use yet.
            n_samples: Upper bound on the dataset length.
            filter_id: Returned as ``pred_label`` and stored in item metadata.
            n_skip: Number of leading rows of the filtered table to skip.

        Raises:
            ValueError: If ``segmentation_type`` is not recognised.
        """
        super().__init__()
        assert split in ['train', 'validation']
        self.name = name
        self.split = split
        self.path_data = path_data
        self.ct_root_dir = Path(ct_root_dir)
        self.nodule_segmentation_dir = Path(nodule_segmentation_dir)
        self.ts_segmentation_dir = Path(ts_segmentation_dir)
        self.n_samples = n_samples
        self.filter_id = filter_id
        # Stored (instead of captured in a lambda) so the dataset stays picklable,
        # e.g. for DataLoader workers started with the "spawn" method.
        self.n_skip = n_skip

        self.data = pd.read_csv(path_data)
        self.original_data = self.data.copy()
        # Remove duplicates resulting from multiple nodules in the same scan
        self.data = self.data.drop_duplicates(subset=["PatientIDWithDate"])
        self.data = self.data.reset_index(drop=True)

        if segmentation_type == "nnunetr":
            log.info("Using nnUNetR segmentations")
            self.suffix = "_nnunet_seg_clean.nii.gz"
        elif segmentation_type == "simple":
            log.info("Using simple segmentations")
            self.suffix = "_seg_clean.nii.gz"
        else:
            raise ValueError(f"Unknown segmentation type: {segmentation_type}")

        # Filter out scans that do not have corresponding segmentations.
        # Globbing on the selected suffix ensures every name we strip really
        # carries that suffix.
        valid_series_uids = {
            f.name.removesuffix(self.suffix)
            for f in self.nodule_segmentation_dir.glob(f"*{self.suffix}")
        }
        self.data = self.data[self.data["SeriesInstanceUID"].isin(valid_series_uids)]
        self.data = self.data.reset_index(drop=True)

        self.predictions = pd.read_csv(path_predictions, index_col = "idx") if path_predictions else None
        # Clamp at zero: __len__ must return a non-negative int, and
        # len(self.data) - n_skip is negative when n_skip exceeds the row count.
        self.length = max(0, min(len(self.data) - n_skip, self.n_samples))

        # NOTE(review): the filter_data(filter_id, n_skip) call (presumably
        # provided by a base class) is disabled; filter_id is only forwarded as
        # pred_label / metadata.
        # self.filter_data(filter_id, n_skip)

    def map_index(self, index):
        """Translate a dataset index into a row index of ``self.data`` by adding ``n_skip``."""
        return index + self.n_skip

    def __len__(self):
        return self.length

    def _load_nodule_labels(self, series_uid) -> dict[int, int]:
        """Read ``<series_uid>_nodule_labels.json`` and return it keyed by integer nodule ID."""
        nodule_labels_path = self.nodule_segmentation_dir / f"{series_uid}_nodule_labels.json"
        with open(nodule_labels_path, 'r') as f:
            raw_labels: dict[str, int] = json.load(f)
        # JSON object keys are always strings; turn them back into integer IDs.
        return {int(k): v for k, v in raw_labels.items()}

    def _find_ts_segmentation(self, patient_id, study_date) -> Path:
        """Return the first (sorted) TotalSegmentator file of the patient whose name contains ``study_date``.

        Raises:
            AssertionError: If the patient directory or a matching file is missing.
        """
        ts_patient_dir = self.ts_segmentation_dir / f"{patient_id}"
        assert ts_patient_dir.exists(), f"TS segmentation directory {ts_patient_dir} does not exist!"

        date_str = str(study_date)
        # glob() order is filesystem-dependent; sorting makes the pick
        # deterministic if several files contain the same date string.
        matching_ts_file = next(
            (f for f in sorted(ts_patient_dir.glob("*.nii.gz")) if date_str in f.name),
            None,
        )
        assert matching_ts_file is not None, f"No matching TS segmentation file found for patient {patient_id} and date {study_date}"
        return matching_ts_file

    def __getitem__(self, index):
        """Load one scan together with its nodule and lung segmentations.

        Returns:
            A tuple ``(index, volume, pred_label, true_label, lungs_seg,
            nodule_seg, metadata)``:

            - ``index``: the row index into ``self.data``, after applying ``map_index``.
            - ``volume``: the output of ``preprocess_scan`` with a leading
              dimension added via ``unsqueeze(0)``.
            - ``pred_label``: ``self.filter_id``.
            - ``true_label``: always ``1``.
            - ``lungs_seg``: 1.0 where the TotalSegmentator label is in 10..14,
              0.0 elsewhere.
            - ``nodule_seg``: the nodule segmentation.
            - ``metadata``: a dict describing the scan and its nodules.
        """
        index = self.map_index(index)
        series_uid = self.data.loc[index, "SeriesInstanceUID"]
        patient_id = self.data.loc[index, "PatientID"]
        study_date = self.data.loc[index, "StudyDate"]

        ct_path = self.ct_root_dir / self.data.loc[index, "path"]
        scan = nib.load(ct_path)
        volume = self.preprocess_scan(scan)

        seg_path = self.nodule_segmentation_dir / f"{series_uid}{self.suffix}"
        nodule_seg = load_nib_segmentation(seg_path, volume.shape)

        nodule_labels_by_id = self._load_nodule_labels(series_uid)
        nodule_ids = sorted(nodule_labels_by_id)
        assert 0 not in nodule_ids, "Nodule IDs should not contain zero at this point, they are 1-based in LUNA25"

        # Get the same year segmentation of the scan
        ts_seg = load_nib_multilabel_segmentation(
            self._find_ts_segmentation(patient_id, study_date), volume.shape
        )
        ## Total segmentator classes for parts of lungs are 10, 11, 12, 13, 14
        lungs_seg = ((ts_seg >= 10) & (ts_seg <= 14)) * 1.0

        nodule_labels = [nodule_labels_by_id[nid] for nid in nodule_ids]
        # Convert to zero-based indexing, keep negative IDs the same for validation
        nodule_ids = [nid - 1 if nid > 0 else nid for nid in nodule_ids]

        metadata = {
            "index": index,
            "image_path": str(ct_path),
            "filter_id": self.filter_id,
            "patient_id": patient_id,
            "series_id": series_uid,
            "original_spacing": get_spacing(scan.affine),
            "nodule_ids": nodule_ids,
            "nodule_labels": nodule_labels,
            "original_affine": scan.affine,
        }

        # TODO: Possibly add predictions
        pred_label = self.filter_id
        true_label = 1  # A scan may contain several nodules; the label is per scan.

        return index, volume.unsqueeze(0), pred_label, true_label, lungs_seg, nodule_seg, metadata
