from typing import List

import torch

from error_sampler.ErrorSampler import ErrorSampler
from models.qr_models.QuantileRegression import QuantileRegression


class QRErrorSampler(ErrorSampler):
    """Error sampler backed by a ``QuantileRegression`` model.

    The model is trained to predict ``errors`` from ``z``. Sampling evaluates
    the fitted inverse CDF at uniformly random quantile levels
    (inverse-transform sampling), giving one error draw per row of ``z``.
    """

    def __init__(self, dataset_name: str, saved_models_path, x_dim: int, y_dim: int, z_dim: int,
                 scaled_y_min: float, scaled_y_max: float,
                 hidden_dims: List[int] = None, dropout: float = 0.1,
                 batch_norm: bool = False,
                 lr: float = 1e-3, wd: float = 0., device='cpu', figures_dir=None,
                 seed=0):
        """Build the underlying quantile-regression model.

        Args:
            dataset_name: Forwarded to ``QuantileRegression``.
            saved_models_path: Forwarded to ``QuantileRegression``.
            x_dim: Accepted for interface compatibility; not used here.
            y_dim: Forwarded to ``QuantileRegression``.
            z_dim: Forwarded to ``QuantileRegression`` in the slot that
                presumably holds its input dimension, since ``fit`` and
                ``sample_error`` feed it ``z`` (TODO confirm).
            scaled_y_min: Forwarded to ``QuantileRegression``.
            scaled_y_max: Forwarded to ``QuantileRegression``.
            hidden_dims: Network hidden sizes, forwarded as-is.
            dropout: Dropout rate, forwarded as-is.
            batch_norm: Batch-norm toggle, forwarded as-is.
            lr: Learning rate, forwarded as-is.
            wd: Weight decay, forwarded as-is.
            device: Torch device, forwarded as-is.
            figures_dir: Forwarded as-is.
            seed: Random seed, forwarded as-is.
        """
        super().__init__()
        # Everything except the positional dimensions is passed by keyword.
        # train_all_q=True is fixed by this sampler.
        model_options = dict(
            hidden_dims=hidden_dims,
            dropout=dropout,
            lr=lr,
            wd=wd,
            batch_norm=batch_norm,
            train_all_q=True,
            device=device,
            figures_dir=figures_dir,
            seed=seed,
            scaled_y_min=scaled_y_min,
            scaled_y_max=scaled_y_max,
        )
        # NOTE(review): the positional 0 is passed through unchanged; its
        # meaning is defined by QuantileRegression's signature.
        self.qr = QuantileRegression(dataset_name, saved_models_path, z_dim, y_dim, 0,
                                     **model_options)

    def fit(self, x_train, z_train, y_train, errors_train, deleted_train, x_val, z_val, y_val, errors_val, deleted_val, **kwargs):
        """Train the quantile model to predict errors from ``z``.

        ``x_*`` and ``y_*`` are accepted for interface compatibility but are not
        used. Error targets are detached from any autograd graph before
        training. Extra keyword arguments go straight to ``QuantileRegression.fit``.
        """
        train_targets = errors_train.detach()
        val_targets = errors_val.detach()
        self.qr.fit(z_train, train_targets, deleted_train,
                    z_val, val_targets, deleted_val,
                    **kwargs)

    def sample_error(self, x_test, z_test):
        """Draw one random error per row of ``z_test``.

        ``x_test`` is unused. A uniform level in [0, 1) is drawn per row on
        ``z_test``'s device and pushed through the fitted inverse CDF.
        """
        _, quantile_fn = self.qr.get_quantile_function(z_test)
        levels = torch.rand(len(z_test), 1, device=z_test.device)
        draws = quantile_fn(levels)
        # NOTE(review): squeeze() drops every size-1 dim, so a single-row
        # z_test presumably yields a 0-d tensor; confirm callers accept that.
        return draws.squeeze()

    @property
    def name(self) -> str:
        """Identifier string for this sampler."""
        # NOTE(review): the value does not mention "qr"; it is kept unchanged
        # because callers may key on it.
        return "nn_error_sampler"
