import torch

from models.qr_models.PredictionIntervalModel import PredictionIntervalModel, PredictionIntervals


class BadQuantileRegression(PredictionIntervalModel):
    """Trivial prediction-interval model that always predicts the interval [0, 0].

    Nothing is learned: ``fit`` and ``eval`` are no-ops, and every sample gets a
    zero-width interval at 0 regardless of the input features or ``alpha``.
    NOTE(review): presumably a deliberately uninformative baseline ("bad_qr") for
    comparisons — confirm against the experiments that use it.
    """

    def __init__(self, alpha: float):
        """Initialize the base model with the requested ``alpha``.

        Args:
            alpha: Forwarded unchanged to ``PredictionIntervalModel.__init__``.
        """
        super().__init__(alpha)

    def fit(self, x_train, y_train, deleted_train, x_val, y_val, deleted_val, epochs=1000, batch_size=64, n_wait=20,
            **kwargs):
        """No-op: this model has no parameters to train.

        All arguments are accepted and ignored; nothing is returned.
        """
        pass

    def eval(self):
        """No-op: there is no state to switch into evaluation mode.

        Returns ``None``. NOTE(review): if ``PredictionIntervalModel`` derives from
        ``torch.nn.Module``, callers expecting ``eval()`` to return ``self`` would
        get ``None`` here — confirm whether that matters.
        """
        pass

    def construct_uncalibrated_intervals(self, x: torch.Tensor) -> PredictionIntervals:
        """Return a [0, 0] interval for every row of ``x``.

        Args:
            x: Input batch. Only ``x.shape[0]`` (assumed to be the number of
                samples) and ``x.device`` are used; the feature values are ignored.

        Returns:
            ``PredictionIntervals`` wrapping a zero tensor of shape
            ``(x.shape[0], 2)`` on ``x.device`` with torch's default dtype.
            Column 0 is the lower bound and column 1 the upper bound.
        """
        # Allocate once, directly on the target device, rather than building two
        # CPU tensors, copying each to the device and concatenating them.
        interval = torch.zeros(x.shape[0], 2, device=x.device)
        return PredictionIntervals(interval)

    @property
    def name(self) -> str:
        """Short identifier for this model."""
        return "bad_qr"
