# Sampling-based calibrators

# ---------------------------------------------------------------------------- #
#                              Logprobs Predictors                             #
# ---------------------------------------------------------------------------- #
# StoppingRuleCalibrator driven by a GeometricQuantilePredictor (no explicit
# base_predictor is configured here), scored with the upper-side CQRScore.
geometric_quantile_logprobs:
  calibrator:
    _target_: generative_prediction_sets.calibrators.StoppingRuleCalibrator
    predictor:
      _target_: generative_prediction_sets.predictors.GeometricQuantilePredictor
    score:
      _target_: generative_prediction_sets.regression_scores.CQRScore
      side: "upper"
  # NOTE(review): presumably selects which field of each example is passed to
  # the predictor; confirm against the code that consumes these entries.
  input_key: normalized_prompt_logprobs

# Same configuration as geometric_quantile_logprobs except that the
# calibrator also receives nonadmissible_handling: k_max.
# NOTE(review): the "_kmax+1" suffix suggests non-admissible examples are
# assigned k_max + 1; confirm in StoppingRuleCalibrator.
geometric_quantile_logprobs_kmax+1:
  calibrator:
    _target_: generative_prediction_sets.calibrators.StoppingRuleCalibrator
    predictor:
      _target_: generative_prediction_sets.predictors.GeometricQuantilePredictor
    score:
      _target_: generative_prediction_sets.regression_scores.CQRScore
      side: "upper"
    # The only key that differs from geometric_quantile_logprobs.
    nonadmissible_handling: k_max
  input_key: normalized_prompt_logprobs

# ---------------------------------------------------------------------------- #
#                            Hidden State Predictors                           #
# ---------------------------------------------------------------------------- #
# StoppingRuleCalibrator whose GeometricQuantilePredictor wraps an
# MlpPredictor base predictor configured with a checkpoint path, scored with
# the upper-side CQRScore.
geometric_quantile_hidden_states:
  calibrator:
    _target_: generative_prediction_sets.calibrators.StoppingRuleCalibrator
    predictor:
      _target_: generative_prediction_sets.predictors.GeometricQuantilePredictor
      base_predictor:
        _target_: generative_prediction_sets.predictors.MlpPredictor
        # OmegaConf interpolation of `config.train` from the root config.
        # NOTE(review): presumably the training run's output directory;
        # confirm.
        checkpoint: ${config.train}/model.ckpt
    score:
      _target_: generative_prediction_sets.regression_scores.CQRScore
      side: "upper"
  input_key: hidden_states

# Same configuration as geometric_quantile_hidden_states except that the
# calibrator also receives nonadmissible_handling: k_max.
# NOTE(review): the "_kmax+1" suffix suggests non-admissible examples are
# assigned k_max + 1; confirm in StoppingRuleCalibrator.
geometric_quantile_hidden_states_kmax+1:
  calibrator:
    _target_: generative_prediction_sets.calibrators.StoppingRuleCalibrator
    predictor:
      _target_: generative_prediction_sets.predictors.GeometricQuantilePredictor
      base_predictor:
        _target_: generative_prediction_sets.predictors.MlpPredictor
        # OmegaConf interpolation of `config.train` from the root config.
        checkpoint: ${config.train}/model.ckpt
    score:
      _target_: generative_prediction_sets.regression_scores.CQRScore
      side: "upper"
    # The only key that differs from geometric_quantile_hidden_states.
    nonadmissible_handling: k_max
  input_key: hidden_states

# ---------------------------------------------------------------------------- #
#                                   Baselines                                  #
# ---------------------------------------------------------------------------- #
# Baseline: a ConstantPredictor (value 0, ndim 2) takes the place of a learned
# predictor; the score is the same upper-side CQRScore used above.
# NOTE(review): the key name suggests this always predicts k_min; confirm how
# value 0 is interpreted by StoppingRuleCalibrator.
k_min_constant:
  calibrator:
    _target_: generative_prediction_sets.calibrators.StoppingRuleCalibrator
    predictor:
      _target_: generative_prediction_sets.predictors.ConstantPredictor
      value: 0
      # NOTE(review): presumably the rank of the constant output; confirm.
      ndim: 2
    score:
      _target_: generative_prediction_sets.regression_scores.CQRScore
      side: "upper"
  input_key: normalized_prompt_logprobs

# Baseline: GeometricQuantilePredictor over a PassThroughPredictor (with
# flatten enabled) reading the `p_map` input, scored with the upper-side
# CQRScore.
# NOTE(review): "oracle" suggests p_map holds reference/ground-truth
# quantities rather than model estimates; confirm.
geometric_quantile_oracle:
  calibrator:
    _target_: generative_prediction_sets.calibrators.StoppingRuleCalibrator
    predictor:
      _target_: generative_prediction_sets.predictors.GeometricQuantilePredictor
      base_predictor:
        _target_: generative_prediction_sets.predictors.PassThroughPredictor
        flatten: true
    score:
      _target_: generative_prediction_sets.regression_scores.CQRScore
      side: "upper"
  input_key: p_map

