# https://github.com/craig-m-k/Recursive-least-squares/blob/master/RLS.ipynb

import numpy as np
import math

import pdb

class splitConformal:
    """Group-conditional split conformal prediction.

    Keeps a separate calibration set for each group. For a new input, it
    returns the largest conformal quantile of the calibration scores over all
    groups that contain that input (the most conservative width).
    """

    def __init__(self, num_groups, groups, delta):

        # num_groups: number of groups we are considering; sizes the
        # per-group calibration storage below.
        self.num_groups = num_groups

        # groups: list of predicates; groups[i](x) is truthy when x belongs
        # to group i.  NOTE(review): assumes len(groups) <= num_groups,
        # otherwise indexing the calibration lists below raises IndexError.
        self.groups = groups

        # calibration_sets_xs: calibration data (input) corresponding to each separate group
        # calibration_sets_ys: calibration data (labels) corresponding to each separate group
        # Each entry stays None until its group receives its first sample.
        self.calibration_sets_xs = [None] * num_groups
        self.calibration_sets_ys = [None] * num_groups

        # delta: determines desired coverage probability (1 - delta)
        self.delta = delta

    def update_calibration_data(self, new_x, new_y):
        """Add the sample (new_x, new_y) to every group that contains new_x.

        Inputs are accumulated row-wise with np.vstack and labels with
        np.hstack. The first input stored for a group is a copy, so later
        in-place changes to the caller's array cannot corrupt the calibration
        set (np.vstack already copies on every later append).
        """
        for i, group in enumerate(self.groups):
            if not group(new_x):
                continue
            if self.calibration_sets_xs[i] is None:
                # Copy instead of storing a reference to the caller's array.
                self.calibration_sets_xs[i] = np.array(new_x)
            else:
                self.calibration_sets_xs[i] = np.vstack((self.calibration_sets_xs[i], new_x))
            if self.calibration_sets_ys[i] is None:
                self.calibration_sets_ys[i] = np.array(new_y)
            else:
                self.calibration_sets_ys[i] = np.hstack((self.calibration_sets_ys[i], new_y))

    def select_best_width(self, scorer, new_x):
        """Return the most conservative conformal width over groups containing new_x.

        For each group i with groups[i](new_x) truthy, the scores
        scorer.calc_score(X_i, y_i) are computed. Their empirical quantile is
        taken at level min(1, ceil((1 - delta) * (n_i + 1)) / n_i), where n_i
        is the number of calibration samples in group i. The maximum over
        those groups is returned.

        Args:
            scorer: object exposing calc_score(xs, ys) -> array of scores.
            new_x: the input whose prediction width is requested.

        Returns:
            -inf if no group contains new_x. +inf if some group containing
            new_x has no calibration data yet, because no finite width can be
            justified. Otherwise, the largest per-group quantile.

        NOTE: when ceil((1 - delta) * (n + 1)) > n, textbook split conformal
        gives an infinite width. This implementation instead clips the level
        to 1.0 and returns the largest calibration score.
        """
        curr_width = -1 * float('inf')
        for i, group in enumerate(self.groups):
            if not group(new_x):
                continue
            calibration_set_x = self.calibration_sets_xs[i]
            calibration_set_y = self.calibration_sets_ys[i]
            if calibration_set_x is None:
                # A relevant group without calibration data cannot bound the
                # score, so the only safe width is unbounded.
                return float('inf')
            residuals = scorer.calc_score(calibration_set_x, calibration_set_y)
            # Count samples as rows after the same promotion np.vstack applies.
            # len() of a single stored 1-D input would count features instead.
            calibration_size = np.atleast_2d(calibration_set_x).shape[0]
            desired_quantile = np.ceil((1 - self.delta) * (calibration_size + 1)) / calibration_size
            chosen_quantile = np.minimum(1.0, desired_quantile)
            w_t = np.quantile(residuals, chosen_quantile)
            curr_width = np.maximum(curr_width, w_t)

        return curr_width

    def all_groups_covered(self):
        """Return True iff every group in self.groups has calibration data."""
        return all(self.calibration_sets_xs[i] is not None for i, _ in enumerate(self.groups))