from typing import List

import numpy as np
import torch
import tqdm

from calibration_schemes.AbstractCalibration import Calibration
from data_utils.data_scaler import DataScaler
from data_utils.datasets.synthetic_dataset_generator import SyntheticDataGenerator
from models.data_mask_estimators.OracleDataMaskerWithDelta import OracleDataMaskerWithDelta
from models.data_mask_estimators.OracleDataMaskerWithDeltaMinMax import OracleDataMaskerWithDeltaMinMax
from models.qr_models.PredictionIntervalModel import PredictionIntervalModel
from models.data_mask_estimators.DataMaskEstimator import DataMaskEstimator
from models.model_utils import ModelPrediction, UncertaintySets
from utils.utils import weighted_quantile, set_seeds, get_seed


class WeightedCalibration(Calibration):
    """
    Implementation of weighted conformal prediction (https://arxiv.org/abs/1904.06019)
    The weights are computed using X,Z. Requires Z to be observed during inference time.

    Every point receives the weight
        w(x, z) = (1 - P_marginal) / (1 - P(missing | x, z)),
    where P(missing | x, z) is predicted by ``data_mask_estimator`` and P_marginal is the mean of those
    predictions over the whole calibration set (including points whose label is deleted).
    With ``OracleDataMaskerWithDelta`` the weights are shifted by a constant delta; with
    ``OracleDataMaskerWithDeltaMinMax`` they are perturbed by random noise clipped to
    [delta_min, delta_max]. Conformity scores and the final sets are delegated to ``base_y_calibration``.
    """

    # (name, raw x, raw z) evaluation points used by compute_cond_x_coverage.
    # They are hard-coded for a dataset with x_dim == 10 and z_dim == 3.
    _COND_X_REFERENCE_POINTS = (
        ('x1', [2.6752, 1.2141, 2.0997, 4.4819, 3.9244, 4.1068, 4.9509, 1.9368, 4.8397, 1.6686],
         [-2.9365, -3.4784, 1.3291]),
        ('x2', [3.7146, 2.5531, 2.4233, 2.2607, 2.0058, 4.1171, 2.2599, 1.4794, 1.7738, 4.5323],
         [1.3932, 0.5212, 2.7791]),
        ('x3', [3.8999, 1.7390, 4.0567, 2.0334, 3.7715, 2.1359, 4.0610, 3.4683, 1.3468, 2.4591],
         [1.5456, 1.8737, 0.9441]),
        ('x4', [3.1712, 2.8301, 3.4082, 1.9410, 4.0644, 2.8850, 1.9367, 1.7227, 3.6824, 3.7627],
         [9.0985e-01, 3.7604e-03, -4.8942e-01]),
        ('x5', [1.3342, 1.3362, 1.1497, 4.8499, 1.0335, 3.4199, 1.4321, 3.7218, 4.5812, 4.2353],
         [-2.7544e+00, -2.2064e-01, 1.0366e+00]),
        ('x6', [1.0400, 1.5751, 2.5961, 3.0510, 2.0910, 4.2115, 3.1603, 1.6593, 3.1852, 1.3479],
         [1.9617e+00, 4.7603e+00, 4.0545e-01]),
        ('x7', [2.8510, 4.0131, 3.5350, 3.8528, 2.8242, 1.5443, 3.9887, 4.0634, 1.9717, 4.7869],
         [-1.7615e-01, -2.5290e+00, -3.9254e+00]),
        ('x8', [2.1610, 3.7517, 4.3001, 3.4903, 2.2104, 1.3324, 3.8821, 1.8632, 3.3865, 2.9740],
         [2.0463e+00, 2.0782e+00, -2.5641e+00]),
        ('x9', [3.8826, 1.8007, 1.7191, 2.5847, 1.4165, 2.1474, 4.5465, 3.7812, 2.4952, 4.7338],
         [1.1314e+00, 5.0326e+00, 1.3640e+00]),
    )

    def __init__(self, base_y_calibration: Calibration, alpha: float,
                 dataset_name: str, data_scaler: DataScaler,
                 data_mask_estimator: DataMaskEstimator, device, quick_mode=False):
        """
        Args:
            base_y_calibration: calibration scheme that computes the conformity scores and builds the
                uncertainty sets from a threshold.
            alpha: miscoverage level; thresholds are weighted (1 - alpha) quantiles of the scores.
            dataset_name: name of the dataset (stored on the instance).
            data_scaler: scaler used by compute_cond_x_coverage to map raw points into model space.
            data_mask_estimator: estimator of the conditional missing probability given (X, Z).
            device: torch device used for tensors created in compute_cond_x_coverage.
            quick_mode: reuse the threshold of the first test point for every test point.
                quick_mode might produce invalid intervals, use with care!
        """
        super().__init__(alpha)
        self.Qs = None
        self.data_mask_estimator = data_mask_estimator
        self.dataset_name = dataset_name
        self.data_scaler = data_scaler
        self.base_y_calibration = base_y_calibration
        self.marginal_missing_probability = None
        # Weights, noise and scores of the non-deleted calibration points; set by calibrate().
        self.weights = None
        self.noise = None
        self.y2_cal_scores = None
        # Full calibration set (including deleted points); set by calibrate().
        self.x_cal = None
        self.z_cal = None
        self.deleted_cal = None
        self.device = device
        self.quick_mode = quick_mode

    def get_mask_probabilities(self, scaled_x, scaled_z):
        """Unimplemented placeholder; always returns None."""
        return

    def fit(self, x_train, y_train, z_train, deleted_train, x_val, y_val, z_val, deleted_val, epochs=1000,
            batch_size=64, n_wait=20,
            **kwargs):
        """
        Fit the parent scheme, the base y-calibration and the data-mask estimator.

        ``n_wait`` is forwarded unchanged to all three ``fit`` calls.
        """
        super().fit(x_train, y_train, z_train, deleted_train, x_val, y_val, z_val, deleted_val, epochs=epochs,
                    batch_size=batch_size, n_wait=n_wait, **kwargs)
        self.base_y_calibration.fit(x_train, y_train, deleted_train, x_val, y_val, deleted_val, epochs,
                                    batch_size,
                                    n_wait, **kwargs)
        self.data_mask_estimator.fit(x_train, z_train, deleted_train, x_val, z_val, deleted_val, epochs=epochs,
                                     batch_size=batch_size, n_wait=n_wait, **kwargs)

    def get_weight(self, conditional_missing_probability, get_noise=False):
        """
        Compute the weights (1 - P_marginal) / (1 - P(missing | x, z)), possibly perturbed.

        Requires ``self.marginal_missing_probability`` to be set (done in calibrate()).

        Args:
            conditional_missing_probability: predicted missing probabilities, one per point.
            get_noise: when True, also return the noise that was added.

        Returns:
            ``w``, or ``(w, noise)`` if ``get_noise``. ``noise`` is None when the mask estimator is not
            one of the delta oracle maskers.
        """
        w = (1 - self.marginal_missing_probability) / (1 - conditional_missing_probability)
        estimator = self.data_mask_estimator
        if isinstance(estimator, OracleDataMaskerWithDeltaMinMax):
            # Noise is drawn under seed 0, and the caller's seed is restored afterwards even if
            # noise generation fails.
            curr_seed = get_seed()
            set_seeds(0)
            try:
                noise = estimator.generate_noise(w.shape[0], w.device)
            finally:
                set_seeds(curr_seed)
            within_bounds = (noise.min().item() >= estimator.delta_min
                             and noise.max().item() <= estimator.delta_max)
            if not within_bounds:
                print("warning: noise exceeds min max thresholds, clipping it instead")
                noise = torch.clip(noise, min=estimator.delta_min, max=estimator.delta_max)
            w += noise
        elif isinstance(estimator, OracleDataMaskerWithDelta):
            noise = torch.ones_like(w) * estimator.delta
            w += noise
        else:
            noise = None
        if get_noise:
            return w, noise
        return w

    def calibrate(self, x_cal, y_cal, z_cal, deleted_cal, cal_prediction: ModelPrediction, **kwargs):
        """
        Calibrate the mask estimator, then store the weights, noise and scores of the labelled points.

        The marginal missing probability is the mean prediction over *all* calibration points.
        Weights, noise and base-calibration scores are kept only for points with ``~deleted_cal``.
        The raw calibration tensors are stored for get_coverage_validity_params.
        """
        super().calibrate(x_cal, y_cal, z_cal, deleted_cal, cal_prediction, **kwargs)
        self.data_mask_estimator.calibrate(x_cal, z_cal, deleted_cal)
        conditional_missing_probability = self.data_mask_estimator.predict(x_cal, z_cal)
        self.marginal_missing_probability = conditional_missing_probability.mean()
        self.weights, self.noise = self.get_weight(conditional_missing_probability, get_noise=True)
        if self.noise is not None:
            self.noise = self.noise[~deleted_cal]
        self.weights = self.weights[~deleted_cal]
        self.y2_cal_scores = self.base_y_calibration.compute_scores(x_cal, y_cal, cal_prediction)[~deleted_cal].detach()
        self.x_cal = x_cal
        self.z_cal = z_cal
        self.deleted_cal = deleted_cal

    def _sorted_scores_with_test_slot(self, device):
        """Return (sorted values, sort permutation) of the calibration scores plus one test slot at the max score."""
        max_score = self.y2_cal_scores.max().item()
        values = torch.cat([self.y2_cal_scores, torch.tensor([max_score], device=device)])
        sorter = torch.argsort(values)
        return values[sorter], sorter

    def construct_calibrated_uncertainty_sets(self, x_test: torch.Tensor,
                                              test_prediction: ModelPrediction, **kwargs) -> UncertaintySets:
        """
        Build weighted-conformal uncertainty sets for ``x_test``.

        For each test point, the threshold is the weighted (1 - alpha) quantile of the calibration scores.
        The calibration points carry mass w_i / (sum_j w_j + w_test), and the test point's mass is
        placed at the maximum calibration score. If there are no labelled calibration points, every
        threshold is infinite.

        NOTE(review): the original WCP paper places the test mass at +inf rather than at the max
        calibration score; confirm the finite slot is intended.

        Raises:
            Exception: if ``z_test`` is not passed in ``kwargs``.
        """
        if 'z_test' not in kwargs:
            raise Exception(f"could not calibrate with method {self.name} without 'z_test'")
        z_test = kwargs['z_test']

        device = x_test.device
        missing_probabilities = self.data_mask_estimator.predict(x_test, z_test)
        test_weights = self.get_weight(missing_probabilities)
        if len(self.y2_cal_scores) == 0:
            thresholds = [torch.inf] * len(x_test)
        else:
            quantiles = torch.Tensor([1 - self.alpha]).to(device)
            values, sorter = self._sorted_scores_with_test_slot(device)
            weights_sum = self.weights.sum()
            thresholds = []
            for i in tqdm.tqdm(range(len(x_test))):
                w_test = test_weights[i]
                normalizer = weights_sum + w_test
                p_i = self.weights / normalizer
                p_test = w_test / normalizer
                sample_weight = torch.cat([p_i, torch.tensor([p_test.item()], device=device)])[sorter]
                Q = weighted_quantile(values, quantiles, sample_weight=sample_weight, old_style=False,
                                      values_sorted=True).item()
                if self.quick_mode:
                    # Reuse the first test point's threshold for all points (fast, possibly invalid).
                    thresholds = torch.ones(len(x_test)) * Q
                    break
                thresholds.append(Q)
        thresholds = torch.Tensor(thresholds).to(device)
        return self.base_y_calibration.compute_uncertainty_set_from_prediction_and_threshold(
            test_prediction, thresholds)

    def compute_scores(self, x, y, cal_prediction: ModelPrediction):
        """Delegate score computation to the base y-calibration."""
        return self.base_y_calibration.compute_scores(x, y, cal_prediction)

    def compute_uncertainty_set_from_prediction_and_threshold(self, test_prediction: ModelPrediction,
                                                              threshold) -> UncertaintySets:
        """Not supported by this scheme."""
        raise NotImplementedError("not implemented yet")

    def jackknife_plus_construct_uncertainty_set_from_scores(self, x_cal, y_cal, z_cal, deleted_cal,
                                                             cal_predictions: List[ModelPrediction],
                                                             x_test,
                                                             test_prediction: List[ModelPrediction],
                                                             z_test=None, **kwargs) -> UncertaintySets:
        """
        Jackknife+ variant: drop deleted calibration points and delegate to the base calibration with weights.

        Both ``cal_predictions`` and ``test_prediction`` are filtered with the calibration mask. They are
        presumably one entry per calibration point (e.g. leave-one-out models).

        NOTE(review): these weights do not include the delta noise that get_weight adds; confirm this is
        intended.

        Raises:
            Exception: if ``z_test`` is None.
        """
        if z_test is None:
            raise Exception(f"{self.name} must get z_test, but got {z_test}")
        kept = ~deleted_cal
        missing_probabilities = self.data_mask_estimator.predict(x_cal, z_cal)
        marginal_missing_probability = missing_probabilities.mean()
        cal_weights = (1 - marginal_missing_probability) / (1 - missing_probabilities)[kept]
        # All remaining points are labelled. Use a bool mask, like every other deletion mask here.
        new_deleted_cal = torch.zeros(len(cal_weights), dtype=torch.bool, device=cal_weights.device)
        test_missing_probabilities = self.data_mask_estimator.predict(x_test, z_test)
        test_weights = (1 - marginal_missing_probability) / (1 - test_missing_probabilities)
        kept_indices = kept.nonzero(as_tuple=True)[0].tolist()
        cal_predictions = [cal_predictions[i] for i in kept_indices]
        test_prediction = [test_prediction[i] for i in kept_indices]
        return self.base_y_calibration.jackknife_plus_construct_uncertainty_set_from_scores(x_cal[kept],
                                                                                            y_cal[kept],
                                                                                            z_cal[kept],
                                                                                            new_deleted_cal,
                                                                                            cal_predictions,
                                                                                            x_test,
                                                                                            test_prediction,
                                                                                            cal_weights=cal_weights,
                                                                                            test_weights=test_weights,
                                                                                            )

    def compute_performance(self, x_test, y, z_test, full_y_test, deleted_test, test_model_prediction: ModelPrediction,
                            **kwargs) -> dict:
        """
        Merge the mask-estimator metrics, the base-calibration metrics and the conditional-coverage metrics.

        ``model`` and ``data_generator`` are optional ``kwargs``. If either is missing, the
        conditional-coverage part is skipped.
        """
        model = kwargs.get('model')
        data_generator = kwargs.get('data_generator')
        return {
            **self.data_mask_estimator.compute_performance(x_test, z_test, full_y_test, deleted_test),
            **self.base_y_calibration.compute_performance(x_test, y, z_test, full_y_test, deleted_test,
                                                          test_model_prediction, ),
            **self.compute_cond_x_coverage(x_test.shape[1], z_test.shape[1], model, data_generator)
        }

    def get_coverage_validity_params(self, x_test, z_test) -> dict:
        """
        Diagnostic quantities comparing clean (noise-free) WCP weights with the injected delta noise.

        Only the first test point is analysed. Returns ``{}`` unless the mask estimator is an
        ``OracleDataMaskerWithDeltaMinMax``. Requires calibrate() to have run.

        Returned keys:
            k / k_wcp: number of sorted scores (including the test slot) that are <= Q_clean_wcp.
            n: number of labelled calibration points.
            C_k, C_n_1: sum of the unnormalised clean weights over the first k / all n + 1 sorted entries.
            delta_tilde_k, delta_tilde_n_1: same sums over noise rescaled by
                (noise - delta_min) / (delta_max - delta_min).
            k_cp: ceil((1 - alpha + 1 / (n + 1)) * n).
            Q_clean_wcp: weighted (1 - alpha) quantile computed with the clean weights.
            sum_clean_p_i: normalised clean mass of the entries <= Q_clean_wcp.
        """
        if not isinstance(self.data_mask_estimator, OracleDataMaskerWithDeltaMinMax):
            return {}
        estimator = self.data_mask_estimator

        if len(x_test.shape) == 1:
            x_test = x_test.unsqueeze(0)
        if len(z_test.shape) == 1:
            z_test = z_test.unsqueeze(0)
        device = x_test.device
        conditional_missing_probability = estimator.predict(self.x_cal, self.z_cal)
        marginal_missing_probability = conditional_missing_probability.mean()
        conditional_missing_probability = conditional_missing_probability[~self.deleted_cal]
        test_conditional_missing_probability = estimator.predict(x_test, z_test)
        clean_calibration_weights = (1 - marginal_missing_probability) / (1 - conditional_missing_probability)
        test_weights = (1 - marginal_missing_probability) / (1 - test_conditional_missing_probability)

        # Calibration noise was fixed in calibrate(). The test noise is drawn fresh (unseeded) on every
        # call, so repeated calls may give different parameters for the same point.
        test_noise = estimator.generate_noise(1, device).reshape(1)
        noise = torch.cat([self.noise, test_noise])
        scaled_noise = (noise - estimator.delta_min) / (estimator.delta_max - estimator.delta_min)

        quantiles = torch.Tensor([1 - self.alpha]).to(device)
        values, sorter = self._sorted_scores_with_test_slot(device)
        scaled_noise = scaled_noise[sorter]
        weights_sum = clean_calibration_weights.sum()
        w_test = test_weights[0]
        p_i = clean_calibration_weights / (weights_sum + w_test)
        p_test = w_test / (weights_sum + w_test)
        scaled_sample_weight = torch.cat([p_i, torch.tensor([p_test.item()], device=device)])[sorter]
        Q = weighted_quantile(values, quantiles, sample_weight=scaled_sample_weight, old_style=False,
                              values_sorted=True).item()
        # `values` is sorted ascending, so the entries <= Q form a prefix of length k. Counting them also
        # covers the case where every entry is <= Q, where argmin of the mask would wrongly yield 0.
        at_or_below_q = values <= Q
        k = int(at_or_below_q.sum().item())
        sum_clean_p_i = scaled_sample_weight[at_or_below_q].sum().item()
        assert np.isclose(sum_clean_p_i, scaled_sample_weight[:k].sum().item())

        n = len(clean_calibration_weights)
        sample_weight = torch.cat([clean_calibration_weights, w_test.unsqueeze(0)])[sorter]
        # NOTE(review): the usual split-conformal rank is ceil((n + 1) * (1 - alpha)); confirm the
        # k_cp formula below is the intended one.
        return {"k_wcp": k, "k": k, 'n': n, 'C_k': sample_weight[:k].sum().item(),
                'C_n_1': sample_weight.sum().item(),
                'delta_tilde_k': scaled_noise[:k].sum().item(),
                'delta_tilde_n_1': scaled_noise.sum().item(),
                'k_cp': int(np.ceil((1 - self.alpha + 1 / (n + 1)) * n)),
                'Q_clean_wcp': Q,
                'sum_clean_p_i': sum_clean_p_i,
                }

    @staticmethod
    def _interval_coverage(ys, intervals):
        """Fraction of ``ys`` lying in [intervals[..., 0], intervals[..., 1]] (both bounds inclusive)."""
        upper = intervals[..., 1].squeeze()
        lower = intervals[..., 0].squeeze()
        return ((ys <= upper) & (ys >= lower)).float().mean().item()

    def compute_cond_x_coverage(self, x_dim, z_dim, model: PredictionIntervalModel,
                                data_generator: SyntheticDataGenerator):
        """
        Estimate coverage conditional on fixed (x, z) reference points by sampling y from the generator.

        Only runs in quick_mode, with a PredictionIntervalModel, a data generator and an
        OracleDataMaskerWithDeltaMinMax. It also requires x_dim/z_dim to match the hard-coded reference
        points. Otherwise it returns ``{}``.

        For each reference point it records the coverage of the calibrated interval, the
        get_coverage_validity_params values, and the coverage of the interval calibrated with the clean
        (noise-free) threshold.
        """
        if data_generator is None or not isinstance(model, PredictionIntervalModel):
            return {}
        if not self.quick_mode:
            return {}
        if not isinstance(self.data_mask_estimator, OracleDataMaskerWithDeltaMinMax):
            return {}
        # The reference points only fit one specific dataset layout; skip others instead of failing in reshape.
        if any(len(raw_x) != x_dim or len(raw_z) != z_dim for _, raw_x, raw_z in self._COND_X_REFERENCE_POINTS):
            return {}
        result = {}

        for x_name, raw_x, raw_z in self._COND_X_REFERENCE_POINTS:
            cond_x = torch.Tensor(raw_x).reshape(1, x_dim).to(self.device).to(torch.float32)
            cond_x = self.data_scaler.scale_x(cond_x)
            cond_z = torch.Tensor(raw_z).reshape(1, z_dim).to(self.device).to(torch.float32)
            cond_z = self.data_scaler.scale_z(cond_z)

            # NOTE(review): the generator receives the *scaled* cond_x/cond_z, yet its output is scaled
            # again with scale_y. Confirm get_y_given_x_z expects scaled inputs and returns unscaled y.
            cond_ys = data_generator.get_y_given_x_z(cond_x, cond_z, repeats=100000, seed=42)
            cond_ys = self.data_scaler.scale_y(cond_ys).squeeze()
            uncalibrated_intervals = model.construct_uncalibrated_intervals(cond_x)
            interval = self.construct_calibrated_uncertainty_sets(cond_x, z_test=cond_z,
                                                                  test_prediction=uncalibrated_intervals).intervals.squeeze()
            result[f'x_{x_name}_coverage'] = self._interval_coverage(cond_ys, interval)

            coverage_validity_params = self.get_coverage_validity_params(cond_x, cond_z)
            Q_clean_wcp = coverage_validity_params['Q_clean_wcp']
            calibrated_intervals_with_clean_weights = self.base_y_calibration.compute_uncertainty_set_from_prediction_and_threshold(
                uncalibrated_intervals, Q_clean_wcp).intervals
            clean_wcp_cond_coverage = self._interval_coverage(cond_ys, calibrated_intervals_with_clean_weights)
            for k, v in coverage_validity_params.items():
                result[f'x_{x_name}_{k}'] = v
            result[f'x_{x_name}_clean_wcp_cond_coverage'] = clean_wcp_cond_coverage

            # Scratch validity conditions from an earlier derivation (not executed), in terms of the params
            # above, with
            #   delta_k   = delta_tilde_k   * (delta_max - delta_min) + delta_min * k
            #   delta_n_1 = delta_tilde_n_1 * (delta_max - delta_min) + delta_min * (n + 1):
            #   (C_k + delta_k) / (C_n_1 + delta_n_1) <= C_k / C_n_1  and  C_n_1 + delta_n_1 > 0
            #   delta = delta_min / (delta_max - delta_min)
            #   req2: delta <= (delta_tilde_n_1 * C_k - delta_tilde_k * C_n_1) / (C_n_1 * k - (n + 1) * C_k)
            #   req3: delta_min > -(delta_max * delta_tilde_n_1 + C_n_1) / (n + 1 - delta_tilde_n_1)

        return result

    @property
    def name(self):
        """Identifier combining the base calibration name and the mask-estimator name."""
        return f"weighted_{self.base_y_calibration.name}_{self.data_mask_estimator.name}_masker"
