from typing import List

import torch
from tensordict import TensorDict

from calibration_schemes.AbstractCalibration import Calibration
from models.qr_models.PredictionIntervalModel import PredictionIntervals
from models.model_utils import ModelPrediction, UncertaintySets, PointPrediction


class PointPredictionCalibration(Calibration):
    """Residual-based calibration that turns point predictions into symmetric intervals.

    The score for a calibration sample is the absolute residual
    ``|prediction - y|``. ``calibrate`` stores a quantile of those residuals in
    ``self.Q``. Test intervals are then ``[prediction - Q, prediction + Q]``,
    stacked along a new trailing axis.
    """

    def __init__(self, alpha: float):
        super().__init__(alpha)
        # Residual quantile learned by `calibrate`; None until calibration has run.
        self.Q = None

    @staticmethod
    def _require_point_prediction(prediction: ModelPrediction, arg_name: str) -> None:
        """Raise TypeError unless `prediction` is a PointPrediction."""
        if not isinstance(prediction, PointPrediction):
            raise TypeError(f"{arg_name} must be of type PointPrediction for dummy calibration")

    @staticmethod
    def _symmetric_interval(center: torch.Tensor, half_width) -> PredictionIntervals:
        """Return ``[center - half_width, center + half_width]`` stacked on a new last axis."""
        return PredictionIntervals(torch.stack([center - half_width, center + half_width], dim=-1))

    def calibrate(self, x_cal, y_cal, z_cal, deleted_cal, cal_prediction: ModelPrediction, **kwargs):
        """Fit the residual threshold ``self.Q`` on the calibration set.

        Args:
            x_cal, z_cal, deleted_cal: forwarded unchanged to the base class.
            y_cal: calibration targets, either a tensor or a TensorDict whose
                ``'y'`` entry holds the targets.
            cal_prediction: must be a PointPrediction for the calibration samples.

        Raises:
            TypeError: if `cal_prediction` is not a PointPrediction.
            ValueError: if the calibration set is empty, or if predictions and
                targets broadcast into a larger tensor than either of them
                (a shape mismatch that would silently corrupt the quantile).
        """
        super().calibrate(x_cal, y_cal, z_cal, deleted_cal, cal_prediction, **kwargs)
        self._require_point_prediction(cal_prediction, "cal_prediction")

        y = y_cal['y'] if isinstance(y_cal, TensorDict) else y_cal
        predictions = cal_prediction.predictions
        n = len(y)
        if n == 0:
            raise ValueError("cannot calibrate on an empty calibration set")

        errors = (predictions - y).abs()
        # e.g. predictions of shape (n,) minus targets of shape (n, 1) broadcasts to
        # an (n, n) matrix of cross-sample residuals; refuse instead of quantiling it.
        if errors.numel() > max(predictions.numel(), y.numel()):
            raise ValueError(
                f"predictions shape {tuple(predictions.shape)} and targets shape "
                f"{tuple(y.shape)} broadcast to {tuple(errors.shape)}; align their shapes"
            )

        # Finite-sample-corrected level, capped at 1 so torch.quantile accepts it.
        # NOTE(review): torch.quantile interpolates linearly by default; the textbook
        # split-conformal guarantee uses the ceil((n+1)(1-alpha))-th order statistic
        # (interpolation='higher'). Confirm the linear variant is intended.
        q_level = min(1 - self.alpha + 1 / (n + 1), 1)
        self.Q = errors.quantile(q=q_level).item()

    def construct_calibrated_uncertainty_sets(self, x_test: torch.Tensor, test_prediction: ModelPrediction,
                                              **kwargs) -> PredictionIntervals:
        """Return ``[prediction - Q, prediction + Q]`` intervals for the test predictions.

        Raises:
            TypeError: if `test_prediction` is not a PointPrediction.
            RuntimeError: if `calibrate` has not been called yet.
        """
        self._require_point_prediction(test_prediction, "test_prediction")
        if self.Q is None:
            raise RuntimeError("calibrate() must be called before constructing uncertainty sets")
        return self._symmetric_interval(test_prediction.predictions, self.Q)

    @property
    def name(self):
        return "residual"

    def compute_scores(self, x, y, cal_prediction: ModelPrediction):
        """Return a zero score per sample, placed on ``y``'s device.

        NOTE(review): this does not compute residuals; it looks like a placeholder.
        Confirm whether callers rely on the all-zero scores.
        """
        return torch.zeros(len(y), device=y.device)

    def compute_uncertainty_set_from_prediction_and_threshold(self, test_prediction: ModelPrediction,
                                                              threshold) -> UncertaintySets:
        """Return ``[prediction - threshold, prediction + threshold]`` intervals.

        Raises:
            TypeError: if `test_prediction` is not a PointPrediction.
        """
        self._require_point_prediction(test_prediction, "test_prediction")
        return self._symmetric_interval(test_prediction.predictions, threshold)

    def jackknife_plus_construct_uncertainty_set_from_scores(self, x_cal, y_cal, z_cal, deleted_cal, scores_cal, x_test, test_prediction: List[ModelPrediction], **kwargs) -> UncertaintySets:
        raise NotImplementedError("jackknife+ is not supported for residual calibration")
