from typing import List

import numpy as np
import torch

from calibration_schemes.AbstractCalibration import Calibration
from imputation_methods.ImputationMethod import ImputationMethod
from models.model_utils import UncertaintySets, ModelPrediction
from utils.utils import get_seed, set_seeds


class CalibrationByImputation(Calibration):
    """Calibrate by first imputing deleted calibration labels, then delegating to a base calibration.

    During ``calibrate`` the calibration set is split deterministically:
    roughly two thirds calibrate ``imputation_method``. The remaining third has its
    deleted labels replaced by the imputation method's predictions and is then passed
    to ``base_calibration`` with no labels marked as deleted. Uncertainty-set
    construction is delegated entirely to ``base_calibration``.
    """

    # Seed used for the (deterministic) imputation/calibration split of the calibration set.
    SPLIT_SEED: int = 42

    def __init__(self, imputation_method: ImputationMethod, base_calibration: Calibration, alpha: float):
        """
        Args:
            imputation_method: model used to fill in labels of deleted calibration samples.
            base_calibration: calibration scheme applied to the imputed calibration labels.
            alpha: miscoverage level, forwarded to ``Calibration.__init__``.
        """
        super().__init__(alpha)
        self.imputation_method = imputation_method
        self.base_calibration = base_calibration

    def jackknife_plus_construct_uncertainty_set_from_scores(self, x_cal, y_cal, z_cal, deleted_cal, scores_cal, x_test,
                                                             model_predictions: List[ModelPrediction],
                                                             **kwargs) -> UncertaintySets:
        """Not supported by this calibration scheme."""
        raise NotImplementedError()

    def fit(self, x_train, y_train, z_train, deleted_train, x_val, y_val, z_val, deleted_val, epochs=1000,
            batch_size=64, n_wait=20, **kwargs):
        """Fit the base ``Calibration`` state, the imputation method and the base calibration on the same data."""
        super().fit(x_train, y_train, z_train, deleted_train, x_val, y_val, z_val, deleted_val, epochs=epochs,
                    batch_size=batch_size, n_wait=n_wait, **kwargs)
        self.imputation_method.fit(x_train, y_train, z_train, deleted_train, x_val, y_val, z_val, deleted_val, epochs,
                                   batch_size, n_wait, **kwargs)
        self.base_calibration.fit(x_train, y_train, z_train, deleted_train, x_val, y_val, z_val, deleted_val, epochs,
                                  batch_size, n_wait, **kwargs)

    def _split_calibration_indices(self, n: int):
        """Deterministically permute ``range(n)`` and split it into (imputation_idx, calibration_idx).

        The first ``2 * n // 3`` permuted indices go to the imputation method, the rest to calibration.
        """
        seed = get_seed()
        set_seeds(self.SPLIT_SEED)
        try:
            rnd_idx = np.random.permutation(n)
        finally:
            # Always hand the caller's seed back, even if the permutation fails.
            # NOTE(review): this re-seeds with the previous seed value rather than restoring the
            # exact generator state -- confirm that is acceptable for downstream randomness.
            set_seeds(seed)
        split = 2 * n // 3
        return rnd_idx[:split], rnd_idx[split:]

    @staticmethod
    def _squeeze_non_batch_dims(y: torch.Tensor) -> torch.Tensor:
        """Remove size-1 dimensions of ``y`` except the leading (sample) dimension.

        Identical to ``y.squeeze()`` when the sample count is not 1, but a single-sample
        tensor keeps its leading dimension so that boolean-mask indexing still works.
        """
        for dim in range(y.dim() - 1, 0, -1):  # walk backwards so remaining dim indices stay valid
            if y.shape[dim] == 1:
                y = y.squeeze(dim)
        return y

    def calibrate(self, x_cal, y_cal, z_cal, deleted_cal, cal_prediction: ModelPrediction, **kwargs):
        """Calibrate the imputation method on part of the data, impute the rest, and calibrate the base scheme.

        Args:
            x_cal, y_cal, z_cal: calibration tensors, indexed along their first dimension.
            deleted_cal: boolean mask marking samples whose label is missing.
            cal_prediction: model prediction for the calibration set; must support indexing by an index array.
        """
        super().calibrate(x_cal, y_cal, z_cal, deleted_cal, cal_prediction, **kwargs)
        imputation_idx, calibration_idx = self._split_calibration_indices(len(x_cal))

        self.imputation_method.calibrate(x_cal[imputation_idx], y_cal[imputation_idx], z_cal[imputation_idx],
                                         deleted_cal[imputation_idx])

        x_cal, y_cal, z_cal, deleted_cal = (x_cal[calibration_idx], y_cal[calibration_idx], z_cal[calibration_idx],
                                            deleted_cal[calibration_idx])
        cal_prediction = cal_prediction[calibration_idx]
        # Squeeze only non-batch dims: a plain squeeze() turns a 1-sample split into a 0-d
        # tensor, which then cannot be indexed by the 1-d deleted_cal mask.
        imputed_y_cal = self._squeeze_non_batch_dims(y_cal.clone())

        if deleted_cal.any():
            imputed_y_cal[deleted_cal] = self.imputation_method.predict(x_cal[deleted_cal], z_cal[deleted_cal]).type(
                imputed_y_cal.dtype).squeeze()

        # After imputation every label is treated as observed by the base calibration.
        new_deleted_cal = torch.zeros(len(x_cal), dtype=torch.bool, device=x_cal.device)
        self.base_calibration.calibrate(x_cal, imputed_y_cal, z_cal, deleted_cal=new_deleted_cal,
                                        cal_prediction=cal_prediction)

    def construct_calibrated_uncertainty_sets(self, x_test: torch.Tensor,
                                              test_uncalibrated_intervals: ModelPrediction,
                                              **kwargs) -> UncertaintySets:
        """Delegate uncertainty-set construction to the base calibration.

        NOTE(review): ``**kwargs`` are not forwarded to the base calibration -- confirm this is intended.
        """
        return self.base_calibration.construct_calibrated_uncertainty_sets(x_test, test_uncalibrated_intervals)

    @property
    def name(self):
        """Identifier combining the imputation method's and the base calibration's names."""
        return f"{self.imputation_method.name}_imputation_{self.base_calibration.name}_calibration"

    def compute_performance(self, x_test, y, z_test, full_y_test, deleted_test, test_model_prediction, **kwargs) -> dict:
        """Return the imputation method's performance metrics on the test set.

        NOTE(review): ``full_y_test`` is passed in place of ``y``; presumably intentional so that
        imputation is scored against the unmasked labels -- confirm. ``y`` and ``**kwargs`` are unused.
        """
        return self.imputation_method.compute_performance(x_test, full_y_test, z_test, full_y_test, deleted_test,
                                                          test_model_prediction)

    def compute_scores(self, x, y, cal_prediction: ModelPrediction):
        """Not supported by this calibration scheme."""
        raise NotImplementedError()

    def compute_uncertainty_set_from_prediction_and_threshold(self, test_prediction: ModelPrediction,
                                                              threshold) -> UncertaintySets:
        """Not supported by this calibration scheme."""
        raise NotImplementedError()
