import torch
import cvxpy as cp
from cvxpylayers.torch import CvxpyLayer
import numpy as np
from model.v1.optimization import BaseOptimization

class PortfolioOptimization(BaseOptimization):
    '''
    Mean-CVaR portfolio selection on the simplex, solved with cvxpylayers.

    Two layers are kept as attributes:
    - self.solver:          min  -y @ z     + gamma * cvar_expr
    - self.robust_solver:   min  -ypred @ z + gamma * cvar_expr + lam * ||z||_1
    both subject to  sum(z) == 1,  z >= 0,  u >= -Y @ z - t,  u >= 0,
    with  cvar_expr = t + sum(u) / ((1 - alpha) * nsample).
    '''

    def __init__(self, model, feasible_region, outcome_space, gamma, alpha, ndim, nsample):
        '''
        Args:
        - model:            forwarded to BaseOptimization
        - feasible_region:  [ nfeasible, nZ ] np or None, pruned to the simplex before being forwarded
        - outcome_space:    forwarded to BaseOptimization
        - gamma:            scalar, weight of the CVaR term (risk aversion)
        - alpha:            scalar, CVaR level
        - ndim:             number of assets (length of z)
        - nsample:          number of return scenarios (rows of Y)
        '''
        feasible_region = self.prune_feasible_region(feasible_region)
        super().__init__(model, feasible_region, outcome_space)
        self.gamma = gamma
        self.alpha = alpha
        self._build_solvers(ndim, nsample)

    def _make_layer(self, ndim, nsample, robust):
        '''
        Build one CvxpyLayer for the (optionally robust) mean-CVaR problem.
        Args:
        - ndim:     number of assets
        - nsample:  number of return scenarios
        - robust:   bool, if True add the lam * ||z||_1 term and the lam parameter
        Return:
        - layer:    CvxpyLayer with parameters [Y, y] (+ [lam] if robust) and variables [z, t, u]
        '''
        # Variables
        z = cp.Variable(ndim)       # portfolio weights
        t = cp.Variable()           # VaR level
        u = cp.Variable(nsample)    # slack vars for CVaR
        # Parameters
        Y = cp.Parameter((nsample, ndim))   # return scenarios [nsample x ndim]
        y = cp.Parameter(ndim)              # expected (or predicted) returns

        cvar_expr = t + (1.0 / ((1 - self.alpha) * nsample)) * cp.sum(u)
        cost = -y @ z + self.gamma * cvar_expr
        parameters = [Y, y]
        if robust:
            lam = cp.Parameter(nonneg=True)     # uncertainty radius
            cost = cost + lam * cp.norm1(z)
            parameters.append(lam)

        constraints = [
            cp.sum(z) == 1,
            z >= 0,
            u >= -Y @ z - t,
            u >= 0
        ]
        problem = cp.Problem(cp.Minimize(cost), constraints)
        return CvxpyLayer(problem, parameters=parameters, variables=[z, t, u])

    def _build_solvers(self, ndim, nsample):
        '''
        (Re)create self.solver and self.robust_solver for the given problem size
        and remember the settings they were built with.
        '''
        self.solver = self._make_layer(ndim, nsample, robust=False)
        self.robust_solver = self._make_layer(ndim, nsample, robust=True)
        # gamma and alpha are baked into the compiled problems, so they are part of the key
        self._solver_key = (nsample, ndim, self.gamma, self.alpha)

    def objective(self, y, z):
        '''
        Args:
        - y:    [ nbatch, nY = nZ ] np
        - z:    [ nbatch, nZ = nY ] np
        Return:
        - obj:   [ nbatch ] np
        '''
        # Row-wise dot product z_i . y_i; avoids materialising the [ nbatch, nbatch ] inner product.
        # Unlike a diagonal of np.inner, this raises if z and y do not have the same shape.
        R       = np.einsum('ij,ij->i', z, y)   # [ nbatch ] np
        cvar    = self.CVaR(self.alpha, z, y)   # [ nbatch ] np
        obj     = R - self.gamma * cvar         # [ nbatch ] np
        return obj

    def regret(self, x_, y_, lam_):
        '''
        Difference between the objective value reached by the robust solution and the
        one reached by the nominal solution, both evaluated with the true returns y_.
        The whole batch y_ is used as the scenario matrix Y of every batch element,
        so the number of scenarios equals nbatch.
        NOTE: the robust value includes the lam * ||z||_1 term.
        Args:
        - x_:   [ nbatch, nX ] np, passed to self.model.pred
        - y_:   [ nbatch, nY ] np
        - lam_: [ nbatch ] np
        Return:
        - loss:   [ nbatch ] np
        '''
        nsample, ndim = y_.shape
        # The layers are compiled for a fixed size; rebuild only when the size or gamma/alpha changed.
        if getattr(self, '_solver_key', None) != (nsample, ndim, self.gamma, self.alpha):
            self._build_solvers(ndim, nsample)

        # -------------
        # Init
        # -------------
        y_batch     = torch.from_numpy(y_).float()                      # [ nbatch, ndim ] torch
        Y_batch     = y_batch.unsqueeze(0).expand(nsample, -1, -1)      # [ nbatch, nsample, ndim ] torch
        ypred_batch = torch.from_numpy(self.model.pred(x_)).float()     # [ nbatch, ndim ] torch
        lam_batch   = torch.from_numpy(lam_).float()                    # [ nbatch ] torch
        # -------------
        # Solve
        # -------------
        z_ro_opt, t_ro_opt, u_ro_opt    = self.robust_solver(Y_batch, ypred_batch, lam_batch)
        z_opt, t_opt, u_opt             = self.solver(Y_batch, y_batch)
        # -------------
        # Evaluate
        # -------------
        cvar_scale = 1 / ((1 - self.alpha) * nsample)
        # (y * z).sum(dim=1) is the row-wise dot product, i.e. the diagonal of y @ z.T
        obj = -(y_batch * z_opt).sum(dim=1) \
            + self.gamma * (t_opt.squeeze() + cvar_scale * u_opt.sum(dim=1))
        obj_ro = -(y_batch * z_ro_opt).sum(dim=1) \
            + self.gamma * (t_ro_opt.squeeze() + cvar_scale * u_ro_opt.sum(dim=1)) \
            + lam_batch * torch.linalg.norm(z_ro_opt, ord=1, dim=1)

        loss        = obj_ro - obj    # [ nbatch ] torch
        loss        = loss.numpy()    # [ nbatch ] np
        assert loss.min() >= -1e-3, f'Violation of loss rule by {loss.min()}. Perhaps increase solver accuracy?'
        return loss

    @staticmethod
    def CVaR(alpha, z, y):
        '''
        Scenario-based risk term: for every row z_i, the sum of z_i . y_j over all rows y_j,
        divided by floor((1 - alpha) * nbatch).
        NOTE(review): the sum runs over all scenarios, not only a tail -- confirm this is intended.
        Args:
        - alpha:    scalar
        - z:        [ nbatch, nZ ] np
        - y:        [ nbatch, nY ] np
        Return:
        - cvar:     [ nbatch ] np
        '''
        m = len(z)
        # NOTE(review): this is 0 when (1 - alpha) * m < 1, which makes the result inf / nan
        tail_count = np.floor((1 - alpha) * m)
        # sum_j z_i . y_j == z_i . (sum_j y_j); avoids the [ nbatch, nbatch ] inner product
        out = 1 / tail_count * np.dot(z, np.sum(y, axis=0))    # [ nbatch ] np
        return out  # [ nbatch ] np

    def prune_feasible_region(self, feasible_region):
        '''
        Safety function
        Keep only the candidate points that lie on the simplex: all entries
        non-negative and entries summing to 1 (absolute tolerance 1e-5).
        Args:
        - feasible_region:  [ nfeasible, nZ ] np, enumeration of the feasible solution set points, or None
        Return:
        - out:              [ <= nfeasible, nZ ] np, pruned points (None if the input is None)
        '''
        if feasible_region is None:
            return None
        # condition 1: positivity
        nonneg = np.all(feasible_region >= 0, axis=1)                       # [ nfeasible ] np
        # condition 2: weights sum to one
        budget = np.abs(feasible_region.sum(axis=1) - 1) < 1e-5             # [ nfeasible ] np
        return feasible_region[np.logical_and(nonneg, budget)]              # [ :, nZ ] np
