# NOTE: might be computationally expensive with the generic "solve" and "robust_solve" methods, which brute-force enumerate the feasible region (and, for "robust_solve", the outcome space).

import numpy as np
from model.v1.predictor import BaseModel 

class BaseOptimization:
    '''Minimization problem solved by enumeration.

    Subclasses implement ``objective``. ``solve`` then minimizes it over the enumerated
    ``feasible_region``. ``robust_solve`` minimizes its worst case over the points of
    ``outcome_space`` that the model's ``contain`` method accepts.
    '''
    def __init__(self, model: "BaseModel", feasible_region=None, outcome_space=None):
        '''
        - model:            a UQ model with a "contain(x, y, lam)" method that defines the uncertainty set
        - feasible_region:  [ :, nZ ] np, enumeration of the feasible solution set points
        - outcome_space:    [ :, nY ] np, enumeration of the outcome space points
        '''
        self.model              = model
        self.feasible_region    = feasible_region
        self.outcome_space      = outcome_space

    @staticmethod
    def _pair_indices(n, m):
        '''Row-major index pairs of the Cartesian product range(n) x range(m).

        Returns two [ n*m ] int arrays (first, second). The first index varies slowest,
        i.e. the same order as ``np.meshgrid(arange(n), arange(m), indexing='ij')`` raveled.
        '''
        return np.repeat(np.arange(n), m), np.tile(np.arange(m), n)

    def regret(self, x, y, lam):
        '''
        Regret of the robust decision relative to the decision that is optimal for y.

        # NOTE: consider changing this to other close-form expression if exist
        Args:
        - x:    [ nbatch, nX ] np
        - y:    [ nbatch, nY ] np
        - lam:  [ nbatch ] np
        Return:
        - loss:   [ nbatch ] np, objective(y, z_robust) - objective(y, z_opt)
        Raises:
        - AssertionError if any regret is negative. This is a sanity check: z_opt minimizes
          objective(y, .) over the same feasible_region that z_robust is drawn from.
        '''
        z_ro_opt, _     = self.robust_solve(x, lam)    # [ nbatch, nZ ] np
        z_opt, _        = self.solve(y)                # [ nbatch, nZ ] np
        loss            = self.objective(y, z_ro_opt) - self.objective(y, z_opt)    # [ nbatch ] np
        # .min() is undefined on an empty batch, so only check non-empty results.
        assert loss.size == 0 or loss.min() >= 0, f'Violation of loss rule by {loss.min()}. Perhaps discretize resolution too low?'
        return loss

    def miscoverage(self, x, y, lam):
        '''
        Miscoverage indicator: 1 - model.contain(x, y, lam).

        Args:
        - x:    [ nbatch, nX ] np
        - y:    [ nbatch, nY ] np
        - lam:  [ nbatch ] np
        Return:
        - loss:   [ nbatch ] np (assumes contain returns bool or 0/1 values)
        '''
        loss   = 1 - self.model.contain(x, y, lam)   # [ nbatch ] np
        return loss

    def objective(self, y, z):
        '''
        Objective to be minimized over z; must be implemented by subclasses.

        Args:
        - y:    [ nbatch, nY ] np
        - z:    [ nbatch, nZ ] np
        Return:
        - obj:   [ nbatch ] np
        '''
        raise NotImplementedError()

    def robust_solve(self, x, lam):
        '''
        Solve min_z max_{y in U(x, lam)} objective(y, z) by enumeration.

        U(x, lam) is the subset of outcome_space for which model.contain is true.

        Args:
        - x:    [ nbatch, nX ] np
        - lam:  [ nbatch ] np
        Return:
        - z:    [ nbatch, nZ ] np, robust-optimal feasible point per batch element
        - loss: [ nbatch ] np, worst-case objective of that point over U(x, lam)
        '''
        nbatch      = x.shape[0]
        noutcome    = self.outcome_space.shape[0]
        nfeasible   = self.feasible_region.shape[0]

        # Every (batch element, candidate outcome) pair.
        b_idx, o_idx = self._pair_indices(nbatch, noutcome)
        x_ext   = x[b_idx]                          # [ nbatch*noutcome, nX ] np
        lam_ext = lam[b_idx]                        # [ nbatch*noutcome ] np
        y_ext   = self.outcome_space[o_idx]         # [ nbatch*noutcome, nY ] np
        # Cast to bool so a 0/1 int or float result selects rows, rather than being
        # interpreted as integer (fancy) indices.
        mask    = np.asarray(self.model.contain(x_ext, y_ext, lam_ext), dtype=bool)   # [ nbatch*noutcome ] np
        assert mask.reshape(nbatch, noutcome).any(axis=1).all(), 'Check contain method. Perhaps input lam is too small?'

        # Every (contained outcome, feasible point) pair.
        y_in    = y_ext[mask]                       # [ nin, nY ] np
        nin     = y_in.shape[0]
        c_idx, f_idx = self._pair_indices(nin, nfeasible)
        obj     = self.objective(y_in[c_idx], self.feasible_region[f_idx])    # [ nin*nfeasible ] np
        obj     = obj.reshape(nin, nfeasible)       # [ nin, nfeasible ] np

        # Worst case over each batch element's contained outcomes. Accumulate straight into
        # [ nbatch, nfeasible ] rather than a dense [ nbatch, noutcome, nfeasible ] tensor.
        # maximum.at handles repeated batch indices (unbuffered).
        robust_obj = np.full((nbatch, nfeasible), -np.inf)
        np.maximum.at(robust_obj, b_idx[mask], obj)

        index   = robust_obj.argmin(axis=1)         # [ nbatch ] np
        z_opt   = self.feasible_region[index]       # [ nbatch, nZ ] np
        loss    = robust_obj.min(axis=1)            # [ nbatch ] np
        # Rejects +-inf and NaN; vacuously true for an empty batch.
        assert np.isfinite(loss).all()
        return z_opt, loss

    def solve(self, y):
        '''
        Solve min_z objective(y, z) over feasible_region by enumeration.

        Args:
        - y:    [ nbatch, nY ] np
        Return:
        - z:    [ nbatch, nZ ] np, optimal feasible point per batch element
        - loss: [ nbatch ] np, objective value at that point
        '''
        nbatch      = y.shape[0]
        nfeasible   = self.feasible_region.shape[0]

        # Every (batch element, feasible point) pair.
        b_idx, f_idx = self._pair_indices(nbatch, nfeasible)
        obj     = self.objective(y[b_idx], self.feasible_region[f_idx])   # [ nbatch*nfeasible ] np
        obj     = obj.reshape(nbatch, nfeasible)    # [ nbatch, nfeasible ] np

        index   = obj.argmin(axis=1)                # [ nbatch ] np
        z_opt   = self.feasible_region[index]       # [ nbatch, nZ ] np
        loss    = obj.min(axis=1)                   # [ nbatch ] np
        return z_opt, loss