import functools
from typing import List, Optional, NamedTuple, Literal
from argparse import Namespace

import torch
import numpy as np
import pydicom
import torchio as tio

from sybil.datasets.utils import order_slices, VOXEL_SPACING
from sybil.utils.loading import get_sample_loader


# Metadata describing one CT series, as assembled by ``Serie._load_metadata``.
# Field order is part of the tuple's positional interface and must not change.
Meta = NamedTuple(
    "Meta",
    [
        # Paths of the slice files (re-ordered via ``order_slices`` for DICOM).
        ("paths", list),
        # Slice thickness: DICOM ``SliceThickness`` or the last entry of the
        # user-supplied voxel spacing for PNG input.
        ("thickness", float),
        # In-plane spacing read from DICOM ``PixelSpacing``; empty for PNG.
        ("pixel_spacing", list),
        # DICOM ``Manufacturer`` tag, or "" when unavailable.
        ("manufacturer", str),
        # One position per slice: z component of ``ImagePositionPatient`` for
        # DICOM, or the running slice index for PNG.
        ("slice_positions", list),
        # Spacing values with a trailing 1 appended; ``Serie.get_volume``
        # places it on the diagonal of the image affine.
        ("voxel_spacing", torch.Tensor),
    ],
)


# Survival-style label produced by ``Serie.get_label``.
# Field order is part of the tuple's positional interface and must not change.
Label = NamedTuple(
    "Label",
    [
        # 1 if the serie is labelled positive within the follow-up window, else 0.
        ("y", int),
        # Per-year vector of length ``max_followup``; 1 from the event year on.
        ("y_seq", np.ndarray),
        # Per-year mask of length ``max_followup``; 1 up to and including the
        # (possibly clipped) censor year, 0 afterwards.
        ("y_mask", np.ndarray),
        # Censor time used to build the mask (clipped for negative labels).
        ("censor_time", int),
    ],
)


class Serie:
    """A single CT series: an ordered stack of slice files plus an optional label.

    On construction the slice metadata is read (from the DICOM headers, or
    from the user-supplied voxel spacing for PNG input) and validated. The
    3D volume itself is only loaded on the first call to ``get_volume`` and
    is then cached on the instance.
    """

    def __init__(
        self,
        dicoms: List[str],
        voxel_spacing: Optional[List[float]] = None,
        label: Optional[int] = None,
        censor_time: Optional[int] = None,
        file_type: Literal["png", "dicom"] = "dicom",
        split: Literal["train", "dev", "test"] = "test",
    ):
        """Initialize a Serie.

        Parameters
        ----------
        `dicoms` : List[str]
            Paths to the slice files making up this serie.
        `voxel_spacing`: Optional[List[float]], optional
            The voxel spacing associated with input CT
            as (row spacing, col spacing, slice thickness).
            Required for PNG input; for DICOM input it is read
            from the file headers instead.
        `label` : Optional[int], optional
            Whether the patient associated with this serie
            has or ever developed cancer.
        `censor_time` : Optional[int]
            Number of years until cancer diagnosis.
            If less than 1 year, should be 0.
            Required whenever `label` is given.
        `file_type`: Literal['png', 'dicom']
            File type of CT slices
        `split`: Literal['train', 'dev', 'test']
            Dataset split into which the serie falls.
            Assumed to be test by default

        Raises
        ------
        ValueError
            If `label` is given without `censor_time`, if PNG input is
            given without `voxel_spacing`, or if the metadata is missing
            or fails validation (see `_load_metadata`, `_check_valid`).
        """
        if label is not None and censor_time is None:
            raise ValueError("censor_time should also provided with label.")
        if file_type == "png" and voxel_spacing is None:
            raise ValueError("voxel_spacing should be provided for PNG files.")

        self._censor_time = censor_time
        self._label = label
        # Per-instance cache for get_volume(); kept on the instance so the
        # (large) tensor is released together with the Serie.
        self._volume = None
        args = self._load_args(file_type)
        self._args = args
        self._loader = get_sample_loader(split, args)
        self._meta = self._load_metadata(dicoms, voxel_spacing, file_type)
        self._check_valid(args)
        # Resample to the reference spacing, then crop/pad to the fixed
        # (H, W, num_images) shape configured in the default args.
        self.resample_transform = tio.transforms.Resample(target=VOXEL_SPACING)
        self.padding_transform = tio.transforms.CropOrPad(
            target_shape=tuple(args.img_size + [args.num_images]), padding_mode=0
        )

    def has_label(self) -> bool:
        """Check if there is a label associated with this serie.

        Returns
        -------
        bool
            True if a label was passed at construction time.
        """
        return self._label is not None

    def get_label(self, max_followup: int = 6) -> "Label":
        """Get the label for this Serie.

        Parameters
        ----------
        max_followup : int, optional
            Number of follow-up years covered by the returned
            vectors, by default 6

        Returns
        -------
        Label
            ``y``: 1 if the label is positive and the censor time falls
            inside the follow-up window, else 0.
            ``y_seq``: vector of length `max_followup`, set to 1 from the
            censor year onwards when ``y`` is 1, all zeros otherwise.
            ``y_mask``: vector of length `max_followup`, 1 for every year up
            to and including the censor year, 0 afterwards.
            ``censor_time``: the censor time, clipped to
            ``max_followup - 1`` when ``y`` is 0.

        Raises
        ------
        ValueError
            If the serie has no label.

        """
        if not self.has_label():
            raise ValueError("No label in this serie.")

        # censor_time is documented in __init__ as a number of years, so it
        # is used directly as an index into the per-year vectors.
        year_to_cancer = self._censor_time  # type: ignore

        y_seq = np.zeros(max_followup, dtype=np.float64)
        y = int((year_to_cancer < max_followup) and self._label)  # type: ignore
        if y:
            y_seq[year_to_cancer:] = 1
        else:
            # Negative (or out-of-window) cases are observed at most until
            # the last follow-up year.
            year_to_cancer = min(year_to_cancer, max_followup - 1)

        y_mask = np.array(
            [1] * (year_to_cancer + 1) + [0] * (max_followup - (year_to_cancer + 1)),
            dtype=np.float64,
        )
        return Label(y=y, y_seq=y_seq, y_mask=y_mask, censor_time=year_to_cancer)

    def get_raw_images(self) -> List[np.ndarray]:
        """
        Load the individual slices of this serie without augmentations.

        A fresh "test" loader with ``apply_augmentations=False`` is used
        instead of the loader created for this serie's split.

        Returns
        -------
        List[np.ndarray]
            One entry per slice path, taken from the ``"input"`` field of
            the loader output (exact type/shape is defined by the loader).
        """

        loader = get_sample_loader("test", self._args, apply_augmentations=False)
        return [loader.get_image(path)["input"] for path in self._meta.paths]

    def get_volume(self) -> torch.Tensor:
        """
        Load the 3D CT volume, resampled and cropped/padded to a fixed shape.

        The result is computed once and cached on this instance; later calls
        return the same tensor object.

        Returns
        -------
        torch.Tensor
            CT volume of shape (1, C, N, H, W)
        """
        # Instance-level cache: unlike functools.lru_cache on a method, this
        # does not keep the Serie (and its volume) alive in a global cache.
        cached = getattr(self, "_volume", None)
        if cached is not None:
            return cached

        slices = [self._loader.get_image(path)["input"] for path in self._meta.paths]

        # Stack per-slice (C, H, W) tensors into (T, C, H, W).
        stacked = torch.stack(slices, dim=0)

        # torchio expects (C, H, W, T); the affine carries the voxel spacing
        # on its diagonal.
        image = tio.ScalarImage(
            affine=torch.diag(self._meta.voxel_spacing),
            tensor=stacked.permute(1, 2, 3, 0),
        )
        image = self.resample_transform(image)
        image = self.padding_transform(image)

        # Back to (C, T, H, W), then add the leading batch dimension.
        volume = image.data.permute(0, 3, 1, 2).unsqueeze(0)
        self._volume = volume
        return volume

    def _load_metadata(self, paths, voxel_spacing, file_type):
        """Extract metadata for the serie without reading pixel data.

        Parameters
        ----------
        `paths` : List[str]
            List of paths to the slice files
        `voxel_spacing`: Optional[List[float]], optional
            The voxel spacing associated with input CT
            as (row spacing, col spacing, slice thickness).
            Only used for PNG input.
        `file_type` : Literal['png', 'dicom']
            File type of CT slices

        Returns
        -------
        Meta
            Paths, slice thickness, pixel spacing, manufacturer, slice
            positions and voxel spacing (with a trailing 1 appended so it
            can be used as the diagonal of an affine).

        Raises
        ------
        ValueError
            If `file_type` is not supported, or if no paths are given
            for DICOM input.
        """
        if file_type == "dicom":
            if len(paths) == 0:
                # Thickness/spacing are read from a DICOM header, so at
                # least one file is required.
                raise ValueError("No DICOM paths were provided for this serie.")

            slice_positions = []
            processed_paths = []
            for path in paths:
                # Headers only: pixel data is loaded later by the loader.
                dcm = pydicom.dcmread(path, stop_before_pixels=True)
                processed_paths.append(path)
                slice_positions.append(float(dcm.ImagePositionPatient[-1]))

            processed_paths, slice_positions = order_slices(
                processed_paths, slice_positions
            )

            # Series-level attributes are taken from the last header read.
            thickness = float(dcm.SliceThickness)
            pixel_spacing = list(map(float, dcm.PixelSpacing))
            manufacturer = getattr(dcm, "Manufacturer", "")
            voxel_spacing = torch.tensor(pixel_spacing + [thickness, 1])
        elif file_type == "png":
            processed_paths = paths
            slice_positions = list(range(len(paths)))
            pixel_spacing = []
            manufacturer = ""
            if voxel_spacing is None:
                thickness = None
            else:
                thickness = voxel_spacing[-1]
                # list(...) so tuples (or other sequences) are accepted too.
                voxel_spacing = torch.tensor(list(voxel_spacing) + [1])
        else:
            raise ValueError(
                f"Unsupported file_type {file_type!r}; expected 'png' or 'dicom'."
            )

        return Meta(
            paths=processed_paths,
            thickness=thickness,
            pixel_spacing=pixel_spacing,
            manufacturer=manufacturer,
            slice_positions=slice_positions,
            voxel_spacing=voxel_spacing,
        )

    def _load_args(self, file_type):
        """
        Load default args required for a single Serie volume

        Parameters
        ----------
        file_type : Literal['png', 'dicom']
            File type of CT slices

        Returns
        -------
        Namespace
            args with preset values
        """
        return Namespace(
            img_size=[256, 256],
            img_mean=[128.1722],
            img_std=[87.1849],
            num_images=200,
            img_file_type=file_type,
            num_chan=3,
            cache_path=None,
            use_annotations=False,
            fix_seed_for_multi_image_augmentations=True,
            slice_thickness_filter=5,
        )

    def _check_valid(self, args):
        """
        Check if serie is acceptable:

        Parameters
        ----------
        `args` : Namespace
            manually set args used to develop model

        Raises
        ------
        ValueError if:
            - slice thickness is missing, OR
            - slice thickness is greater than `args.slice_thickness_filter`, OR
            - voxel spacing is missing
        """
        if self._meta.thickness is None:
            raise ValueError("slice thickness not found")
        if self._meta.thickness > args.slice_thickness_filter:
            raise ValueError(
                f"slice thickness {self._meta.thickness} is greater than {args.slice_thickness_filter}."
            )
        if self._meta.voxel_spacing is None:
            raise ValueError("voxel spacing either not set or not found in DICOM")
