import torch

from data_utils.data_scaler import DataScaler
from data_utils.dataset_naming_utils import get_original_dataset_name
from data_utils.get_dataset_utils import get_data_generator
from models.qr_models.PredictionIntervalModel import PredictionIntervalModel, PredictionIntervals


class OracleQuantileRegression(PredictionIntervalModel):
    """Oracle prediction-interval model backed by the dataset's data generator.

    Nothing is learned. For each input ``x`` the model draws ``repeat_size``
    conditional samples from the generator returned by ``get_data_generator``.
    It then reports the empirical lower and upper conditional quantiles of
    those samples, mapped back through ``data_scaler.scale_y``, as the
    interval bounds.
    """

    def __init__(self, dataset_name: str, original_dataset_name: str, x_dim: int, y_dim: int, z_dim: int, alpha: float, data_scaler: DataScaler,
                 device='cpu', seed=0, repeat_size=500):
        """Build the oracle.

        Args:
            dataset_name: Name used to build the data generator. A ``"_z"``
                suffix makes the oracle sample ``z`` instead of ``y``.
            original_dataset_name: Not used by this class. It is kept so the
                constructor signature matches the other models.
            x_dim: Forwarded to ``get_data_generator``.
            y_dim: Not used by this class. It is kept for the same reason.
            z_dim: Forwarded to ``get_data_generator``.
            alpha: Miscoverage level. It is passed to
                ``PredictionIntervalModel.__init__`` and later read back as
                ``self.alpha``.
            data_scaler: Unscales ``x`` before sampling and scales the
                resulting quantiles.
            device: Stored on the instance. This class does not otherwise
                use it.
            seed: Seed passed to every conditional-sampling call.
            repeat_size: Number of conditional samples drawn per input.

        Raises:
            ValueError: If ``repeat_size`` is not a positive number.
        """
        PredictionIntervalModel.__init__(self, alpha)
        if repeat_size <= 0:
            raise ValueError(f"repeat_size must be positive, got {repeat_size}")
        # Honor the caller's sample count. The default (500) keeps the
        # previous behavior for callers that did not pass one.
        self.repeat_size = repeat_size
        self.data_generator = get_data_generator(dataset_name, x_dim, z_dim)
        self.device = device
        self.data_scaler = data_scaler
        self.seed = seed
        self.dataset_name = dataset_name

    def fit(self, x_train, y_train, deleted_train, x_val, y_val, deleted_val, epochs=1000, batch_size=64, n_wait=20,
            **kwargs):
        """Do nothing: the oracle samples the true generator, so no training is needed."""
        pass

    def eval(self):
        """Do nothing: the oracle has no train/eval mode to switch."""
        pass

    def get_data_given_x(self, unscaled_x, repeats, seed):
        """Draw ``repeats`` conditional samples for each row of ``unscaled_x``.

        If ``dataset_name`` ends with ``"_z"``, samples come from the
        generator's ``generate_z_given_x``. Otherwise they come from
        ``get_y_given_x``. The ``x`` passed in must be in the original
        (unscaled) units.
        """
        if self.dataset_name.endswith("_z"):
            return self.data_generator.generate_z_given_x(unscaled_x, repeats=repeats, seed=seed)
        else:
            return self.data_generator.get_y_given_x(unscaled_x, repeats=repeats, seed=seed)

    def __get_conditional_quantile(self, x: torch.Tensor, quantile_level: float) -> torch.Tensor:
        """Return the scaled empirical conditional quantile of the target at ``x``.

        Args:
            x: Scaled inputs. They are unscaled before sampling.
            quantile_level: Quantile level in ``[0, 1]``, as required by
                ``torch.quantile``.

        Returns:
            The empirical quantile over the sample dimension, mapped through
            ``data_scaler.scale_y``.
        """
        unscaled_x = self.data_scaler.unscale_x(x)
        sample = self.get_data_given_x(unscaled_x, repeats=self.repeat_size, seed=self.seed)
        # NOTE(review): assumes the repeats are stacked along dim 0 of
        # `sample`; confirm against the data generator.
        quantiles = torch.quantile(sample, q=quantile_level, dim=0)
        return self.data_scaler.scale_y(quantiles)

    def construct_uncalibrated_intervals(self, x: torch.Tensor) -> PredictionIntervals:
        """Build ``[q_low, q_high]`` oracle intervals for each input in ``x``.

        The bounds are the ``corrected_alpha / 2`` and
        ``1 - corrected_alpha / 2`` conditional quantiles, stacked along a new
        last dimension (lower bound first).
        """
        # Subtracting one sample's worth of mass from alpha widens the
        # interval, presumably as a finite-sample correction for the
        # empirical quantile. The clamp at 0 keeps both quantile levels
        # inside [0, 1]; torch.quantile raises otherwise, which happens
        # whenever alpha < 1 / repeat_size.
        corrected_alpha = max(self.alpha - 1 / self.repeat_size, 0.0)
        alpha_high = 1 - corrected_alpha / 2
        alpha_low = corrected_alpha / 2
        # Both calls use self.seed. If the generator is deterministic given
        # the seed (TODO confirm), both bounds come from the same sample.
        q_high = self.__get_conditional_quantile(x, alpha_high)
        q_low = self.__get_conditional_quantile(x, alpha_low)
        interval = torch.stack([q_low, q_high], dim=-1)
        return PredictionIntervals(interval)

    @property
    def name(self) -> str:
        """Identifier used for this model."""
        return "oracle_qr"
