#!/usr/bin/env python
# coding: utf-8
"""
Portfolio optimization solvers: a differentiable CVXPY layer and a GurobiPy model
"""
import gurobipy as gp  # pylint: disable=no-name-in-module
import numpy as np
import torch

from gurobipy import GRB 

import cvxpy as cp

from cvxpylayers.torch import CvxpyLayer

from openpto.method.Solvers.abcptoSolver import ptoSolver


class CpPortfolioSolver(ptoSolver):
    """Portfolio allocation problem exposed as a ``CvxpyLayer``.

    The wrapped problem chooses weights ``x`` over ``num_stocks`` assets to
    maximize ``p @ x - alpha * sum_squares(L @ x)`` subject to
    ``0 <= x <= 1`` and ``sum(x) == 1``. ``p`` and ``L`` are CVXPY parameters
    whose values are supplied on every call to :meth:`solve`.
    """

    def __init__(self, num_stocks, modelSense=None, alpha=0.1, **kwargs):
        """Build the layer for ``num_stocks`` assets.

        Args:
            num_stocks: Number of decision variables (one weight per asset).
            modelSense: Forwarded unchanged to the ``ptoSolver`` constructor.
            alpha: Coefficient on the quadratic penalty term of the objective.
            **kwargs: Accepted for call-site compatibility; not used here.
        """
        super().__init__(modelSense)
        print("alpha: ", alpha)
        self.num_stocks = num_stocks
        self.solver = self._create_cvxpy_problem(alpha)

    @property
    def num_vars(self):
        """Number of decision variables, equal to ``num_stocks``."""
        return self.num_stocks

    def _create_cvxpy_problem(self, alpha):
        """Assemble the CVXPY problem and wrap it in a ``CvxpyLayer``.

        Args:
            alpha: Coefficient on ``sum_squares(L @ x)`` in the objective.

        Returns:
            A ``CvxpyLayer`` whose parameters are, in order, the linear
            coefficient vector and the square matrix ``L``, and whose single
            output variable is the weight vector.
        """
        weights = cp.Variable(self.num_stocks)
        sqrt_matrix = cp.Parameter((self.num_stocks, self.num_stocks))
        linear_coeffs = cp.Parameter(self.num_stocks)

        # Box bounds on every weight plus the budget (sum-to-one) constraint.
        feasible_set = [weights >= 0, weights <= 1, cp.sum(weights) == 1]

        reward = linear_coeffs.T @ weights
        penalty = alpha * cp.sum_squares(sqrt_matrix @ weights)
        layer_problem = cp.Problem(cp.Maximize(reward - penalty), feasible_set)

        # Parameter order here fixes the positional argument order of solve().
        return CvxpyLayer(
            layer_problem,
            parameters=[linear_coeffs, sqrt_matrix],
            variables=[weights],
        )

    def solve(self, Y, sqrt_covar):
        """Run the layer on ``Y`` (linear term) and ``sqrt_covar`` (matrix ``L``).

        Returns:
            Whatever the underlying ``CvxpyLayer`` call returns, unmodified.
        """
        return self.solver(Y, sqrt_covar)
    
class PortfolioSolver(ptoSolver):
    """Portfolio allocation solver backed by a Gurobi model.

    The model holds one continuous variable per asset, bounded to ``[0, 1]``,
    with a budget constraint forcing the variables to sum to one. Each call to
    :meth:`solve` installs the objective
    ``max  y @ z - alpha * (z @ Q @ z)`` and optimizes it.
    """

    def __init__(self, num_stocks, modelSense=None, alpha=0.1, **kwargs):
        """Create the Gurobi model for ``num_stocks`` assets.

        Args:
            num_stocks: Number of decision variables (one weight per asset).
            modelSense: Forwarded unchanged to the ``ptoSolver`` constructor.
            alpha: Coefficient on the quadratic term of the objective.
            **kwargs: Accepted for call-site compatibility; not used here.
        """
        super().__init__(modelSense)
        self.num_stocks = num_stocks
        self._model, self.z = self._getModel()
        self.alpha = alpha
        # Silence Gurobi's own solver log.
        self._model.Params.outputFlag = 0

    def _getModel(self):
        """Build the model with its variables and the budget constraint.

        Returns:
            Tuple ``(model, z)`` of the Gurobi model and its ``MVar`` of
            length ``num_stocks``.
        """
        n = self.num_stocks
        model = gp.Model("MarkowitzOptimization")

        z = model.addMVar((n,), lb=0, ub=1, vtype=gp.GRB.CONTINUOUS, name="z")
        model.addConstr(z.sum() == 1, name="BudgetConstraint")

        return model, z

    @staticmethod
    def _to_numpy(value):
        """Convert a tensor (on any device) or array-like to a NumPy array."""
        if isinstance(value, torch.Tensor):
            # Tensor.numpy() only works for CPU tensors; .cpu() is a no-op
            # when the tensor already lives on the CPU.
            return value.detach().cpu().numpy()
        return np.asarray(value)

    def setObj(self, y, Q):
        """Install the objective ``max  y @ z - alpha * (z @ Q @ z)``.

        Args:
            y: Linear coefficients, as a torch tensor or array-like.
            Q: Matrix of the quadratic form, as a torch tensor or array-like.
        """
        y_numpy = self._to_numpy(y)
        Q_numpy = self._to_numpy(Q)

        linear_term = y_numpy @ self.z

        # Quadratic term: alpha * (z @ Q @ z).
        quadratic_term = self.alpha * (self.z @ Q_numpy @ self.z)

        # Objective: maximize (linear_term - quadratic_term).
        obj_expr = linear_term - quadratic_term
        self._model.setObjective(obj_expr, gp.GRB.MAXIMIZE)

    def solve(self, Y, sqrt_covar):
        """Set the objective from the inputs, optimize, and return the weights.

        Args:
            Y: Linear coefficients of the objective.
            sqrt_covar: Passed as-is to :meth:`setObj` as the quadratic-form
                matrix ``Q``.
                NOTE(review): despite the parameter name, the matrix is used
                directly in ``z @ Q @ z`` and is not squared here; confirm
                callers pass the matrix they intend.

        Returns:
            A ``torch.float32`` tensor holding the optimized value of each
            decision variable.
        """
        self.setObj(Y, sqrt_covar)
        self._model.update()
        # NOTE(review): progress markers kept to preserve existing output.
        print("begin")
        self._model.optimize()
        print("end")
        z_sol = np.array([self.z[i].x for i in range(self.num_stocks)])
        return torch.tensor(z_sol, dtype=torch.float32)
