import numpy as np
from model.v1.optimization import BaseOptimization

class Newsvendor(BaseOptimization):
    '''
    Newsvendor cost model on top of ``BaseOptimization``.

    For a realized quantity ``y`` (presumably demand) and a decision ``z``
    (presumably the order quantity), the per-sample cost is

        - p * min(y, z) + c * z - v * (z - y)^+

    i.e. ``p`` weighs the units sold ``min(y, z)``, ``c`` weighs every unit
    ordered and ``v`` weighs each leftover unit ``(z - y)^+``. For a known
    ``y`` this cost is minimized at ``z = y`` when ``p >= c >= v``, which is
    what makes the closed-form ``regret`` below non-negative.
    '''

    def __init__(self, model, p, v, c, feasible_region=None, outcome_space=None):
        '''
        Args:
        - model:            predictor forwarded to ``BaseOptimization``;
                            ``regret`` calls ``model.pred(x)``
        - p:                coefficient of ``min(y, z)`` (units sold)
        - v:                coefficient of ``(z - y)^+`` (leftover units)
        - c:                coefficient of ``z`` (units ordered)
        - feasible_region:  forwarded unchanged to ``BaseOptimization``
        - outcome_space:    forwarded unchanged to ``BaseOptimization``
        '''
        super().__init__(model, feasible_region, outcome_space)
        self.p = p
        self.v = v
        self.c = c

    def objective(self, y, z):
        '''
        Per-sample cost  - p * min(y, z) + c * z - v * (z - y)^+.

        Args:
        - y:    [ nbatch, 1 ] or [ nbatch ] np
        - z:    [ nbatch, 1 ] or [ nbatch ] np
        Return:
        - obj:  [ nbatch ] np
        '''
        # Flatten so column vectors and 1-D batches are handled identically.
        y = np.asarray(y).reshape(-1)
        z = np.asarray(z).reshape(-1)
        sold = np.minimum(y, z)
        leftover = np.clip(z - y, a_min=0, a_max=None)
        return - self.p * sold + self.c * z - self.v * leftover

    def regret(self, x, y, lam):
        '''
        Closed-form regret  objective(y, z) - objective(y, y)  for the
        decision  z = model.pred(x) - lam; cheaper than calling ``objective``
        twice. Expanding the objective gives

            (p - c) * (y - z)^+ + (c - v) * (z - y)^+

        Args:
        - x:    [ nbatch, nX ] np
        - y:    [ nbatch, 1 ] np  (assumes nY == 1, like ``objective``)
        - lam:  [ nbatch ] np or scalar, subtracted from the prediction
        Return:
        - loss: [ nbatch ] np
        Raises:
        - ValueError: if any regret is negative, which only happens when the
          coefficients violate p >= c >= v.
        '''
        # Query the model once: a second call is wasted work and, for a
        # stochastic predictor, would produce a different decision.
        # NOTE(review): assumes model.pred returns [ nbatch, 1 ] — confirm.
        z = self.model.pred(x) - np.reshape(lam, (-1, 1))
        shortage = np.clip(y - z, a_min=0, a_max=None)  # (y - z)^+
        overage = np.clip(z - y, a_min=0, a_max=None)   # (z - y)^+
        # The salvage coefficient v must appear here so that the regret agrees
        # with ``objective``, which credits v per leftover unit.
        loss = (self.p - self.c) * shortage + (self.c - self.v) * overage
        # Explicit check instead of ``assert`` (stripped under -O); empty
        # batches are skipped because ``min`` of an empty array raises.
        if loss.size and loss.min() < 0:
            raise ValueError(
                'negative regret: Newsvendor requires p >= c >= v '
                f'(got p={self.p}, c={self.c}, v={self.v})'
            )
        return loss.reshape(-1)