
import numpy as np
from model.v1.optimization import BaseOptimization

class ConformalRegretControl:
    """Conformal and plain Monte-Carlo estimates of regret / miscoverage.

    Wraps an optimization object that exposes ``regret(x, y, lams)`` and
    ``miscoverage(x, y, lams)``. Both are evaluated with the same threshold
    ``lam`` for every sample, and their per-sample results are summed.
    """

    def __init__(self, optimization: "BaseOptimization"):
        # Forward-reference (string) annotation: same hint for type checkers,
        # but the name need not be resolvable when the signature is evaluated.
        self.optimization = optimization

    @property
    def optimzation(self):
        """Backward-compatible alias for the formerly misspelled attribute."""
        return self.optimization

    @optimzation.setter
    def optimzation(self, value):
        self.optimization = value

    def estimate(self, x, y, lam, B, also_mc=False):
        '''Estimate regret and miscoverage levels at threshold ``lam``.

        Conformal estimates (always returned):
        - alpha_R = (sum_i regret_i + B) / (n + 1)
        - alpha_I = (sum_i miscoverage_i + 1) / (n + 1)
          (same form with 1 in place of B -- presumably the upper bound of
          the miscoverage indicator; NOTE(review): confirm)

        Monte-Carlo baseline (only when ``also_mc`` is true):
        - alpha_R_mc = sum_i regret_i / n
        - alpha_I_mc = sum_i miscoverage_i / n
        These divide by ``n``, so they require at least one sample.

        Args:
        - x:    [ nbatch, nX ]; only ``x.shape[0]`` (n) is read here
        - y:    leading dimension nbatch (previously documented as
                [ nbatch, nX ]; NOTE(review): confirm trailing dimension)
        - lam:  scalar threshold, repeated once per sample
        - B:    upper bound of the regret
        - also_mc: also return the plain Monte-Carlo means

        Returns:
        - (alpha_R, alpha_I), or
        - (alpha_R, alpha_I, alpha_R_mc, alpha_I_mc) when ``also_mc``.
        '''
        n = x.shape[0]
        # One copy of the threshold per sample, as the optimization API takes.
        lams = np.ones([n]) * lam

        opt = self.optimization
        reg_sum = opt.regret(x, y, lams).sum()
        mis_sum = opt.miscoverage(x, y, lams).sum()

        # Conformal form: the unseen test point contributes its worst case
        # (B for regret, 1 for miscoverage) to an (n + 1)-sample average.
        alpha_R = reg_sum / (n + 1) + B / (n + 1)    # scalar
        alpha_I = mis_sum / (n + 1) + 1 / (n + 1)    # scalar

        if not also_mc:
            return alpha_R, alpha_I

        # Baseline method: plain empirical means over the n samples.
        alpha_R_mc = reg_sum / n    # scalar
        alpha_I_mc = mis_sum / n    # scalar
        return alpha_R, alpha_I, alpha_R_mc, alpha_I_mc
