from typing import List

import torch
from calibration_schemes.AbstractCalibration import Calibration
from calibration_schemes.CalibrationByImputation import CalibrationByImputation
from calibration_schemes.DummyCalibration import DummyCalibration
from calibration_schemes.TwoStagedConformalPrediction import TwoStagedCalibration
from models.model_utils import ModelPrediction, UncertaintySets


class TriplyRobustCalibration(Calibration):
    """Ensemble calibration over a ``DummyCalibration`` and two supplied schemes.

    Every ``fit`` / ``calibrate`` call is forwarded, with the same arguments, to
    the base class and then to each member in the order
    ``[dummy_calibration, calibration1, calibration2]``. At test time the
    members' uncertainty sets are merged with ``UncertaintySets.union``.

    Score computation, threshold-based set construction and jackknife+ are not
    supported and raise ``NotImplementedError``.
    """

    def __init__(self, alpha: float, calibration1: Calibration,
                 calibration2: Calibration):
        """Register the two wrapped calibrations and build the dummy member.

        Args:
            alpha: level passed to the base class and to ``DummyCalibration``.
            calibration1: first wrapped calibration scheme.
            calibration2: second wrapped calibration scheme.
        """
        super().__init__(alpha)
        self.calibration1 = calibration1
        self.calibration2 = calibration2
        self.dummy_calibration = DummyCalibration(alpha)
        # Members are fitted, calibrated and merged in exactly this order.
        self.calibrations = [self.dummy_calibration, calibration1, calibration2]

    def fit(self, x_train, y_train, z_train, deleted_train, x_val, y_val, z_val, deleted_val,
            epochs=1000, batch_size=64, n_wait=20, **kwargs):
        """Fit the base class, then every member, on the same train/validation data."""
        data = (x_train, y_train, z_train, deleted_train, x_val, y_val, z_val, deleted_val)
        options = dict(epochs=epochs, batch_size=batch_size, n_wait=n_wait, **kwargs)
        super().fit(*data, **options)
        for member in self.calibrations:
            member.fit(*data, **options)

    def calibrate(self, x_cal, y_cal, z_cal, deleted_cal, cal_prediction: ModelPrediction, **kwargs):
        """Calibrate the base class, then every member, on the same calibration data."""
        data = (x_cal, y_cal, z_cal, deleted_cal, cal_prediction)
        super().calibrate(*data, **kwargs)
        for member in self.calibrations:
            member.calibrate(*data, **kwargs)

    def construct_calibrated_uncertainty_sets(self, x_test: torch.Tensor,
                                              test_prediction: ModelPrediction, **kwargs) -> UncertaintySets:
        """Return the union of all members' calibrated uncertainty sets.

        While the running result is still ``None``, the next member's set
        replaces it instead of being unioned. Otherwise each new set is merged
        via ``union``.
        """
        combined = None
        for member in self.calibrations:
            piece = member.construct_calibrated_uncertainty_sets(x_test, test_prediction, **kwargs)
            combined = piece if combined is None else combined.union(piece)
        return combined

    def compute_scores(self, x, y, cal_prediction: ModelPrediction):
        """Not supported for this ensemble calibration."""
        raise NotImplementedError("not implemented yet")

    def compute_uncertainty_set_from_prediction_and_threshold(self, test_prediction: ModelPrediction,
                                                              threshold) -> UncertaintySets:
        """Not supported for this ensemble calibration."""
        raise NotImplementedError("not implemented yet")

    @property
    def name(self):
        """Identifier combining the names of the two wrapped calibrations."""
        first = self.calibration1.name
        second = self.calibration2.name
        return f"triply_robust_{first}_{second}"

    def jackknife_plus_construct_uncertainty_set_from_scores(self, x_cal, y_cal, z_cal, deleted_cal, scores_cal, x_test,
                                                             model_predictions: List[ModelPrediction],
                                                             **kwargs) -> UncertaintySets:
        """Not supported for this ensemble calibration."""
        raise NotImplementedError("not implemented yet")
