import numpy as np
from model.v1.optimization import BaseOptimization
import torch
import cvxpy as cp
from cvxpylayers.torch import CvxpyLayer

class LinearProgramming(BaseOptimization):
    '''
    Linear program over the polytope { z : A z <= b }.

    Two CvxpyLayer solvers are built once in the constructor:
    - self.solver:          min_z  y @ z                    s.t.  A z <= b
    - self.robust_solver:   min_z  mu @ z + rho * sum(t)    s.t.  A z <= b, -t <= z <= t, t >= 0
      (robust counterpart for costs in an l_infty ball of radius rho around mu)
    '''

    def __init__(self, model, feasible_region, outcome_space, A, b):
        '''
        Args:
        - model, feasible_region, outcome_space:
                forwarded unchanged to BaseOptimization.__init__
        - A:    [ nconst, ndim ]
        - b:    [ nconst ]
        '''
        # NOTE: A and b are stored before the base constructor runs, as in the
        # original ordering (presumably the base class may read them -- confirm).
        self.A = A
        self.b = b
        super().__init__(model, feasible_region, outcome_space)

        # Re-read from self after the base constructor has run.
        A = self.A
        b = self.b
        ndim = A.shape[1]

        # -------------
        # Defining the robust LP
        # -------------
        z = cp.Variable(ndim)
        t = cp.Variable(ndim)               # elementwise bound on |z|
        mu = cp.Parameter(ndim)             # prediction center
        rho = cp.Parameter(nonneg=True)     # radius of the l_infty ball

        # With -t <= z <= t and rho >= 0, minimising pushes t down to |z|, so the
        # objective equals mu @ z + rho * ||z||_1, i.e. the worst case of y @ z
        # over all y with ||y - mu||_infty <= rho.
        objective = cp.Minimize(mu @ z + rho * cp.sum(t))
        constraints = [
            A @ z <= b,
            -t <= z,
            z <= t,
            t >= 0
        ]
        problem = cp.Problem(objective, constraints)
        self.robust_solver = CvxpyLayer(problem, parameters=[mu, rho], variables=[z, t])

        # -------------
        # Defining the standard LP
        # -------------
        z = cp.Variable(ndim)
        y = cp.Parameter(ndim)              # cost vector
        objective = cp.Minimize(y @ z)
        constraints = [
            A @ z <= b
        ]
        problem = cp.Problem(objective, constraints)
        self.solver = CvxpyLayer(problem, parameters=[y], variables=[z])

    def objective(self, y, z):
        '''
        Row-wise linear objective y_i @ z_i.

        Args:
        - y:    [ nbatch, ndim ] np
        - z:    [ nbatch, ndim ] np
        Return:
        - obj:   [ nbatch ] np
        '''
        # Row-wise dot product; avoids materialising the [ nbatch, nbatch ]
        # product y @ z.T just to read its diagonal.
        return (y * z).sum(axis=1)

    # NOTE: assume that the UQ set is l_infty ball
    # NOTE: efficient implementation of the regret computation
    def regret(self, x, y, lam):
        '''
        Regret of the robust decision relative to the decision that is optimal
        for the true cost y: objective(y, z_ro) - objective(y, z_opt).

        Args:
        - x:    [ nbatch, ndim ] np
        - y:    [ nbatch, ndim ] np
        - lam:  [ nbatch ] np, radius of the l_infty ball per sample
        Return:
        - loss:     [ nbatch ] np
        Raises:
        - AssertionError if any regret is below -1e-3 (both decisions share the
          same feasible set, so the regret should be non-negative up to solver
          tolerance).
        '''
        # Init
        y_batch     = torch.from_numpy(y).float()                   # [ nbatch, ndim ] torch
        mu_batch    = torch.from_numpy(self.model.pred(x)).float()  # [ nbatch, ndim ] torch
        rho_batch   = torch.from_numpy(lam).float()                 # [ nbatch ] torch
        # Solve (each layer returns one tensor per declared variable)
        z_ro_opt, _ = self.robust_solver(mu_batch, rho_batch)   # z, t: [ nbatch, ndim ] torch
        z_opt       = self.solver(y_batch)[0]                   # [ nbatch, ndim ] torch
        # detach() so the conversion also works if an output is attached to an
        # autograd graph; it is a no-op otherwise.
        z_ro_opt    = z_ro_opt.detach().numpy()
        z_opt       = z_opt.detach().numpy()
        loss        = self.objective(y, z_ro_opt) - self.objective(y, z_opt)    # [ nbatch ] np
        assert loss.min() >= -1e-3, f'Violation of loss rule by {loss.min()}. Perhaps increase solver accuracy?'
        return loss

# helper
def circle_as_polytope(R=1.0, m=32):
    '''
    Build the half-space description Az <= b of an m-sided polygon
    approximating a circle of radius R centred at the origin.

    Args:
    - R:    circle radius (right-hand side of every constraint)
    - m:    number of polygon sides / constraints
    Return:
    - A:    [ m, 2 ] np, unit normals (cos(theta_k), sin(theta_k))
    - b:    [ m ] np, all entries equal to R
    NOTE: example: A, b = circle_as_polytope(R=1.0, m=16)
    '''
    # m equally spaced angles in [0, 2*pi); the endpoint is dropped so the
    # first normal is not duplicated.
    angles = np.linspace(0, 2*np.pi, m, endpoint=False)
    normals = np.column_stack((np.cos(angles), np.sin(angles)))     # [m, 2]
    offsets = np.multiply(np.ones(m), R)                            # [m]
    return normals, offsets