from abc import ABC, abstractmethod
import numpy as np
from torch import nn
import torch
import torch.nn.functional as F
import json

from .base import ClassifierBase
from utils.iterp import torch_interp
from pathlib import Path
import functools

import torchio as tio

import sys
sys.path.append("src/classifiers/src_sybil")
from sybil.models.sybil import SybilNet

# Intensity statistics used to standardize the 8-bit quantized volume
# (see apply_sybil_transforms).
SYBIL_MEAN, SYBIL_STD = 128.1722, 87.1849

# Voxel spacing of the inpainter's volumes, and the spacing / grid that
# apply_sybil_transforms resamples and crops/pads to before calling SybilNet.
INPAINTER_SPACING = (1.0, 1.0, 2.0)
TRUE_SYBIL_SPACING = (1.40625, 1.40625, 2.5)
TRUE_SYBIL_INPUT_SHAPE = (256, 256, 200)

# Decision threshold on the year-1 risk used by Sybil.pred_label.
# TODO: placeholder -- the value still has to be computed.
SYBIL_YEAR1_60P = None

# Sybil's Hounsfield-unit window: center -600 HU, width 1500 HU,
# i.e. the HU range [-1350, 150] (stored as floats).
SYBIL_WINDOW_CENTER, SYBIL_WINDOW_WIDTH = -600, 1500
SYBIL_MIN_HU = SYBIL_WINDOW_CENTER - 0.5 * SYBIL_WINDOW_WIDTH
SYBIL_MAX_HU = SYBIL_WINDOW_CENTER + 0.5 * SYBIL_WINDOW_WIDTH


@functools.lru_cache(maxsize=1)
def _get_transform():
    """Build (once) the torchio resample and crop/pad transforms for Sybil's grid.

    Returns:
        ``(resample, crop_pad)``: a ``tio.transforms.Resample`` to
        ``TRUE_SYBIL_SPACING`` and a ``tio.transforms.CropOrPad`` to
        ``TRUE_SYBIL_INPUT_SHAPE`` that pads with 0. Because of ``lru_cache``,
        every call returns the same pair of objects.
    """
    return (
        tio.transforms.Resample(TRUE_SYBIL_SPACING),
        tio.transforms.CropOrPad(TRUE_SYBIL_INPUT_SHAPE, padding_mode=0),
    )


def _resample(x: torch.Tensor, current_spacing, target_spacing):
    # x: (B, C, W, H, D)
    N, C, W, H, D = x.shape
    shape = np.array([W, H, D])
    current_spacing = np.array(current_spacing, dtype=np.float32)
    target_spacing = np.array(target_spacing, dtype=np.float32)
    new_shape = (current_spacing * shape / target_spacing).astype(np.int64)
    x = F.interpolate(x, size=tuple(new_shape), mode='trilinear', align_corners=True)
    return x


def _croppad(x: torch.Tensor, target_shape):
    # x: (B, C, W, H, D)
    N, C, W, H, D = x.shape
    target_shape = np.array(target_shape)
    shape = np.array([W, H, D])
    diff = target_shape - shape
    pad = np.maximum(diff, 0)
    crop = np.maximum(-diff, 0)

    pl = pad // 2
    pr = pad - pl
    cl = crop // 2
    cr = crop - cl

    padding = (pl[2], pr[2], pl[1], pr[1], pl[0], pr[0])
    x = F.pad(x, padding, mode='constant', value=0)
    slices = (slice(cl[0], x.shape[2] - cr[0]), slice(cl[1], x.shape[3] - cr[1]), slice(cl[2], x.shape[4] - cr[2]))
    x = x[:, :, slices[0], slices[1], slices[2]]
    return x



def apply_sybil_transforms(x: torch.Tensor, x_min_hu: int = -1025, x_max_hu: int = 3071):
    """Convert a [0, 1]-scaled CT volume into the tensor layout Sybil consumes.

    Steps:
      1. Map [0, 1] back to Hounsfield units using ``x_min_hu`` / ``x_max_hu``,
         apply Sybil's HU window (``SYBIL_WINDOW_CENTER`` / ``SYBIL_WINDOW_WIDTH``)
         and rescale the windowed values to [0, 1].
      2. Scale to [0, 255], quantize through ``torch.uint8`` (truncation) and
         standardize with ``SYBIL_MEAN`` / ``SYBIL_STD``.
      3. Resample from ``INPAINTER_SPACING`` to ``TRUE_SYBIL_SPACING`` and
         center-crop / zero-pad to ``TRUE_SYBIL_INPUT_SHAPE``.
      4. Reorder to (N, C, D, W, H) and repeat a single channel to three.

    Args:
        x: volume of shape (N, C, W, H, D) with values in [0, 1].
        x_min_hu: HU value that 0 in ``x`` corresponds to.
        x_max_hu: HU value that 1 in ``x`` corresponds to.

    Returns:
        ``(volume, resampled_shape)``: the model input, and the 5-D shape right
        after resampling (before crop/pad).

    NOTE(review): the uint8 cast is not differentiable, so no gradient flows back
    to ``x`` through this function -- confirm that is intended if it is ever used
    for gradient-based guidance.
    """
    # Back to HU, then Sybil's window mapped onto [0, 1].
    hu = x * (x_max_hu - x_min_hu) + x_min_hu
    windowed = torch.clamp((hu - SYBIL_WINDOW_CENTER) / SYBIL_WINDOW_WIDTH, -1/2, 1/2)
    unit = (2*windowed + 1) / 2

    # 8-bit quantization (Sybil expects quantized inputs), then standardization.
    pixels = unit * 255
    pixels = pixels.to(torch.uint8).to(pixels.dtype)
    standardized = (pixels - SYBIL_MEAN) / SYBIL_STD

    # Match Sybil's voxel spacing and input grid.
    resampled = _resample(standardized, current_spacing=INPAINTER_SPACING, target_spacing=TRUE_SYBIL_SPACING)
    resampled_shape = resampled.shape
    volume = _croppad(resampled, target_shape=TRUE_SYBIL_INPUT_SHAPE)

    # (N, C, W, H, D) -> (N, C, D, W, H): depth becomes the "time" axis.
    volume = volume.permute(0, 1, 4, 2, 3)
    if volume.shape[1] == 1:
        volume = volume.repeat(1, 3, 1, 1, 1)

    return volume, resampled_shape


class SimpleClassifierGroupTorch(nn.Module):
    """
    A class to represent a calibrator for prediction models.
    Behavior and coefficients are taken from the sklearn.calibration.CalibratedClassifierCV class.
    Make a custom class to avoid sklearn versioning issues.

    The group averages the outputs of its calibrators to obtain a positive-class
    probability. Adapted from Sybil to work with PyTorch.
    """

    def __init__(self, calibrators):
        """
        Parameters
        ----------
        calibrators : list
            Objects exposing ``transform(X) -> torch.Tensor``. They are kept in a
            plain list, so they are not registered as submodules of this module.
        """
        super().__init__()
        self.calibrators = calibrators

    def predict_proba(self, X, expand=False):
        """
        Recalibrate probabilities by averaging every calibrator's output.

        Parameters
        ----------
        X : torch.Tensor
            The input probabilities; passed unchanged to each calibrator's ``transform``.
        expand : bool, default=False
            Only honored when the group holds exactly one calibrator: also return
            the negative-class probability, stacked along a new leading dimension.

        Returns
        -------
        torch.Tensor
            The mean calibrated positive-class probability, shaped like a single
            calibrator's output. With ``expand=True`` and a single calibrator, a
            tensor of shape ``(2, *pos_prob.shape)`` holding ``[1 - p, p]``.
        """
        proba = torch.stack([calibrator.transform(X) for calibrator in self.calibrators])
        pos_prob = torch.mean(proba, dim=0)

        if expand and len(self.calibrators) == 1:
            # torch.stack (instead of torch.tensor on a list of tensors) works for any
            # number of samples and preserves device, dtype and autograd history.
            return torch.stack([1.0 - pos_prob, pos_prob])
        return pos_prob

    @classmethod
    def from_json(cls, json_list):
        """Build a group from a list of serialized isotonic-regressor dicts."""
        return cls([SimpleIsotonicRegressorTorch.from_json(json_dict) for json_dict in json_list])

    @classmethod
    def from_json_grouped(cls, json_path):
        """
        Load calibrators stored as a dictionary of {year (str): [calibrators]}.

        This is a convenience method to load that dictionary from a file path.
        Returns a dict mapping each key to a ``SimpleClassifierGroupTorch``.
        """
        # Context manager so the file handle is closed even if parsing fails.
        with open(json_path, "r", encoding="utf-8") as f:
            json_dict = json.load(f)
        return {key: cls.from_json(json_list) for key, json_list in json_dict.items()}


class SimpleIsotonicRegressorTorch(nn.Module):
    """
    Calibrator: optional affine map, clipping, then interpolation through (x0, y0).

    Adapted from Sybil to work with PyTorch.
    """

    def __init__(self, coef, intercept, x0, y0, x_min=-torch.inf, x_max=torch.inf):
        """
        Parameters
        ----------
        coef, intercept : torch.Tensor or None
            Affine map applied as ``X @ coef + intercept`` before clipping. When
            ``coef`` is None the input is used as-is.
        x0, y0 : torch.Tensor
            Interpolation knots passed to ``torch_interp``.
        x_min, x_max : float
            Bounds used to clip the (affine-mapped) input.
        """
        super().__init__()
        self.coef = coef
        self.intercept = intercept
        self.x0 = x0
        self.y0 = y0
        self.x_min = x_min
        self.x_max = x_max

    @staticmethod
    def _on(value, device):
        """Move ``value`` to ``device`` if it is a tensor; return other values unchanged."""
        return value.to(device) if isinstance(value, torch.Tensor) else value

    def transform(self, X):
        """Map input scores ``X`` to calibrated values."""
        T = X
        if self.coef is not None:
            # Align parameters with X's device so the calibrator keeps working when
            # inputs live on a different device than the one used at load time.
            T = X @ self._on(self.coef, X.device) + self._on(self.intercept, X.device)
        T = torch.clip(T, self.x_min, self.x_max)
        return torch_interp(T, self._on(self.x0, T.device), self._on(self.y0, T.device))

    @classmethod
    def from_json(cls, json_dict, device="cuda"):
        """
        Build a regressor from a dict with keys ``coef``, ``intercept``, ``x0``,
        ``y0`` and optionally ``x_min`` / ``x_max`` (default -inf / +inf).

        ``device`` defaults to ``"cuda"`` (the previous hard-coded behavior); pass
        ``"cpu"`` on machines without a GPU. Null entries are kept as None.
        """
        def as_tensor(value):
            return None if value is None else torch.tensor(value, device=device)

        return cls(
            as_tensor(json_dict["coef"]),
            as_tensor(json_dict["intercept"]),
            as_tensor(json_dict["x0"]),
            as_tensor(json_dict["y0"]),
            json_dict.get("x_min", -torch.inf),
            json_dict.get("x_max", torch.inf),
        )

    def __repr__(self):
        return f"SimpleIsotonicRegressorTorch(x={self.x0}, y={self.y0})"



class Sybil(ClassifierBase):
    """
    Wrapper around a pretrained SybilNet that consumes [0, 1]-scaled CT volumes
    and returns calibrated per-year risk scores.
    """

    def __init__(
            self,
            path_ckpt: str,
            calibrator_path: str,
            x_min_hu: int = -1025,
            x_max_hu: int = 3071,
            num_years: int = 6,
            *args, **kwargs
        ):
        """
        Args:
            path_ckpt: path to a pickled Sybil model checkpoint (see ``load_model``).
            calibrator_path: JSON file of per-year calibrators keyed "Year1", "Year2", ...
            x_min_hu: HU value corresponding to 0 in the inputs.
            x_max_hu: HU value corresponding to 1 in the inputs.
            num_years: stored as ``self.num_years``.

        Raises:
            FileNotFoundError: if either path does not exist.
        """
        super().__init__(*args, **kwargs)

        path_ckpt = Path(path_ckpt)
        calibrator_path = Path(calibrator_path)
        self.x_min_hu = x_min_hu
        self.x_max_hu = x_max_hu
        self.num_years = num_years

        # Optional context for `precondition`: original volumes and patch corners.
        self.x_orig = None
        self.mask_locations = None

        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if not path_ckpt.exists():
            raise FileNotFoundError(f"ckpt_dir {path_ckpt} does not exist")
        if not calibrator_path.exists():
            raise FileNotFoundError(f"calibrator_path {calibrator_path} does not exist")

        self.model = self.load_model(path_ckpt)
        self.calibrator = SimpleClassifierGroupTorch.from_json_grouped(calibrator_path)

        self.eval()

    def load_model(self, path: Path):
        """Rebuild a SybilNet from a fully pickled checkpoint and copy its weights.

        The checkpoint is an entire pickled model object (its
        ``prob_of_failure_layer.args`` and ``state_dict()`` are used). It must
        therefore be unpickled with ``weights_only=False``, because torch >= 2.6
        defaults to ``weights_only=True`` and refuses it. Unpickling can execute
        arbitrary code: only load trusted checkpoints.
        """
        try:
            ckpt_model = torch.load(path, map_location="cpu", weights_only=False)
        except TypeError:  # torch < 1.13 has no `weights_only` argument
            ckpt_model = torch.load(path, map_location="cpu")
        args = ckpt_model.prob_of_failure_layer.args  # Extract args from checkpoint
        model = SybilNet(args)
        model.load_state_dict(ckpt_model.state_dict())  # type: ignore
        return model

    def calibrate(self, scores):
        """Apply each year's calibrator to the matching column of ``scores``.

        Column ``k`` of ``scores`` (2-D, one column per year) is recalibrated with
        ``self.calibrator["Year{k+1}"]``; the results are stacked along dim 1.
        """
        calibrated_scores = []
        for year in range(scores.shape[1]):
            probs = scores[:, year].reshape(-1, 1)
            probs = self.calibrator["Year{}".format(year + 1)].predict_proba(probs)[:, -1]
            calibrated_scores.append(probs)

        return torch.stack(calibrated_scores, dim=1)

    def forward_all_years(self, x):
        """Get the sigmoid of the base hazard and calibrated risk for all years.

        The output concatenates ``sigmoid(base_hazard)`` and the calibrated yearly
        risks along dim 1, clipped to [0, 1]. Assumes ``base_hazard`` has a single
        column, so column 0 is the base hazard and column 1 is year 1 -- TODO confirm.
        """
        x, _ = apply_sybil_transforms(x, self.x_min_hu, self.x_max_hu)
        y = self.model(x)
        scores = self.calibrate(y["logit"].sigmoid())
        base_hazard = y["base_hazard"].sigmoid()

        y = torch.concat([base_hazard, scores], dim=1)
        y = y.clip(0, 1)

        return y

    def forward(self, x):
        """Get the year 1 risk (column 1 of ``forward_all_years``) as the forward output."""
        scores = self.forward_all_years(x)
        return scores[:, 1]

    def pred_prob(self, x):
        """Return the year-1 risk duplicated into two float columns, shape (N, 2)."""
        x = self.precondition(x, None)
        y = self.forward(x)
        # This is a hack to make binary classification compatible with the rest of the code
        return torch.stack([y, y], dim=1).float()

    def pred_label(self, x, threshold=None):
        """Binarize the year-1 risk: 1 (high risk) if above ``threshold``, else 0 (low risk).

        Args:
            x: input accepted by ``pred_prob``.
            threshold: decision threshold; defaults to the module-level ``SYBIL_YEAR1_60P``.

        Raises:
            ValueError: if no threshold is given and ``SYBIL_YEAR1_60P`` is still unset.
        """
        if threshold is None:
            threshold = SYBIL_YEAR1_60P
        if threshold is None:
            # Comparing a tensor against None would otherwise fail with an opaque TypeError.
            raise ValueError(
                "No year-1 decision threshold available: SYBIL_YEAR1_60P is still a "
                "placeholder; pass `threshold=` explicitly."
            )
        probs = self.pred_prob(x)
        return (probs[:, 0] > threshold).long()

    def set_precondition_imgs(self, x, mask_locations):
        """Store the original volumes and patch corner locations used by ``precondition``."""
        self.x_orig = x
        self.mask_locations = mask_locations

    def precondition(self, inpaint, y):
        """Paste inpainted patches back into the stored original volumes.

        When originals were set via ``set_precondition_imgs``, the patch region of
        each ``self.x_orig[i]`` is zeroed and replaced by ``inpaint[i]``; otherwise
        ``inpaint`` is used as-is. Single-channel results are repeated to 3 channels.

        Args:
            inpaint: [0, 1]-range tensor. With originals set, a batch of patches
                (N, C, pw, ph, pd); otherwise full volumes (N, C, W, H, D).
            y: unused; kept for interface compatibility.

        Returns:
            Tensor of shape (N, C', W, H, D).
        """
        if self.x_orig is not None:
            inpainted = []
            for i in range(inpaint.shape[0]):
                orig = self.x_orig[i].clone()
                _, w, h, d = orig.shape
                # Patch extent taken from the patch itself (previously hard-coded to 64).
                pw, ph, pd = inpaint[i].shape[-3:]

                # NOTE(review): every sample uses the first mask location; confirm this
                # is intended rather than `self.mask_locations[i]`.
                x_s, y_s, z_s = self.mask_locations[0]

                # Zero the patch region, then add the patch zero-padded to full size.
                orig[:, x_s:(x_s + pw), y_s:(y_s + ph), z_s:(z_s + pd)] = 0
                pad = (
                    z_s, d - (z_s + pd),
                    y_s, h - (y_s + ph),
                    x_s, w - (x_s + pw),
                )
                inpainted.append(orig + F.pad(inpaint[i], pad))

            x = torch.stack(inpainted).float()
        else:
            x = inpaint

        # Handle single channel images
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1, 1)

        # N, C', W, H, D
        return x

    def reset_precondition(self):
        """Clear the stored originals so ``precondition`` passes inputs through."""
        self.set_precondition_imgs(None, None)

    def load_ckpt(self, *args, **kwargs):
        """No-op: the model is already loaded in ``__init__``."""
        pass
