#!/usr/bin/env python
# coding: utf-8
"""
Abstract optimization model based on GurobiPy
"""


import cvxpy as cp
import gurobipy as gp  # pylint: disable=no-name-in-module
import numpy as np
import torch

from gurobipy import GRB 
from cvxpylayers.torch import CvxpyLayer

from openpto.method.Solvers.abcptoSolver import ptoSolver

class BmatchingSolverCVXPY(ptoSolver):
    """
    Bipartite matching solver expressed with CVXPY instead of GurobiPy.

    Solves the same model as the GurobiPy-based ``BmatchingSolver``:
    binary decision variables ``z[i, j]`` whose rows and columns each sum
    to one, maximizing ``sum_{i,j} y[i, j] * z[i, j]``.

    Gurobi is used through CVXPY when it is installed and usable. If it is
    missing or fails (e.g. no/size-limited license), CVXPY's default
    choice among the installed mixed-integer solvers is used instead.

    We **recommend using the GurobiPy version** for production because:
    1. It is typically faster, especially for large N.
    2. It avoids rebuilding and re-canonicalizing a CVXPY problem on every
       ``solve`` call (this class builds a new ``cp.Problem`` per call).

    Use this CVXPY version for prototyping or when no Gurobi license is
    available.
    """

    def __init__(self, modelSense=None, isTrain=True, num_nodes=50, **kwargs):
        super().__init__(modelSense)
        self.num_nodes = num_nodes
        self._problem, self.z = self._getProblem(isTrain, num_nodes)

    @property
    def num_vars(self):
        """Number of decision variables (``num_nodes ** 2``)."""
        return self.num_nodes * self.num_nodes

    def _getProblem(self, isTrain=True, num_nodes=50):
        """Build the matching problem with a placeholder objective.

        ``isTrain`` is currently unused. Returns ``(problem, z)``, where ``z``
        is the ``(num_nodes, num_nodes)`` boolean ``cp.Variable``.
        """
        n = num_nodes
        z = cp.Variable((n, n), boolean=True)  # Binary decision variables
        constraints = [
            cp.sum(z, axis=1) == 1,  # every row sums to one
            cp.sum(z, axis=0) == 1,  # every column sums to one
        ]
        # The real objective is installed by setObj(); this is a placeholder.
        problem = cp.Problem(cp.Maximize(0), constraints)
        return problem, z

    def _as_matrix(self, y):
        """Return ``y`` as a ``(num_nodes, num_nodes)`` float NumPy array.

        Accepts torch tensors (on any device, with or without grad) or
        array-likes, either already square or flat of length ``num_nodes**2``.
        """
        if isinstance(y, torch.Tensor):
            # .cpu() so CUDA tensors work; .detach() drops autograd history.
            y = y.detach().cpu().numpy()
        return np.asarray(y, dtype=float).reshape(self.num_nodes, self.num_nodes)

    def setObj(self, y):
        """Install the objective ``maximize sum_{i,j} y[i, j] * z[i, j]``."""
        y_matrix = self._as_matrix(y)
        objective = cp.Maximize(cp.sum(cp.multiply(y_matrix, self.z)))
        # cvxpy's Problem.objective is read-only, so assigning to it raises
        # AttributeError; build a new Problem that reuses the constraints.
        self._problem = cp.Problem(objective, self._problem.constraints)

    def _run_solver(self):
        """Solve ``self._problem``: Gurobi if usable, else CVXPY's default."""
        if cp.GUROBI in cp.installed_solvers():
            try:
                self._problem.solve(solver=cp.GUROBI, verbose=False)
                return
            except (cp.error.SolverError, gp.GurobiError):
                # Deliberate fallback, e.g. missing or size-limited license.
                pass
        # Let CVXPY pick an installed solver that supports boolean variables.
        self._problem.solve(verbose=False)

    def solve(self, y, **kwargs):
        """Solve for objective coefficients ``y``.

        Returns ``(z_tensor, obj_val)``. ``z_tensor`` is the flattened
        ``num_nodes**2`` float32 solution, or all zeros if the status is not
        optimal. ``obj_val`` is the CVXPY problem value.
        """
        self.setObj(y)
        self._run_solver()

        n = self.num_nodes
        if self._problem.status == cp.OPTIMAL:
            z_sol = self.z.value
        else:
            z_sol = np.zeros((n, n))
        obj_val = self._problem.value
        z_tensor = torch.tensor(z_sol.flatten(), dtype=torch.float32)
        return z_tensor, obj_val
    
class BmatchingSolver(ptoSolver):
    """Bipartite matching solver built directly on GurobiPy.

    Decision variables ``z[i, j]`` are binary. Every row and every column
    of ``z`` must sum to one, so a feasible ``z`` is an ``n x n``
    permutation matrix. ``solve(y)`` maximizes
    ``sum_{i,j} y[i, j] * z[i, j]``.
    """

    def __init__(self, modelSense=None, isTrain=True, num_nodes=50, **kwargs):
        super().__init__(modelSense)
        self.num_nodes = num_nodes
        self._model, self.z = self._getModel(isTrain, num_nodes)
        # Suppress Gurobi's solver log output.
        self._model.Params.outputFlag = 0

    @property
    def num_vars(self):
        """Number of decision variables (``num_nodes ** 2``)."""
        return self.num_nodes * self.num_nodes

    def _getModel(self, isTrain=True, num_nodes=50):
        """Build the matching model (maximization, no objective yet).

        ``isTrain`` is currently unused. Returns ``(model, z)``, where ``z``
        is the tupledict of binary variables keyed by ``(i, j)``.
        """
        n = num_nodes
        m = gp.Model()
        z = m.addVars(n, n, name="z", vtype=GRB.BINARY)
        m.modelSense = GRB.MAXIMIZE
        for i in range(n):
            m.addConstr(gp.quicksum(z[i, j] for j in range(n)) == 1,
                        name=f"RowSum_{i}")
        for j in range(n):
            m.addConstr(gp.quicksum(z[i, j] for i in range(n)) == 1,
                        name=f"ColSum_{j}")
        return m, z

    def _as_matrix(self, y):
        """Return ``y`` as a ``(num_nodes, num_nodes)`` float NumPy array.

        Accepts torch tensors (on any device, with or without grad) or
        array-likes, either already square or flat of length ``num_nodes**2``.
        """
        if isinstance(y, torch.Tensor):
            # .cpu() so CUDA tensors work; .detach() drops autograd history.
            y = y.detach().cpu().numpy()
        return np.asarray(y, dtype=float).reshape(self.num_nodes, self.num_nodes)

    def setObj(self, y):
        """Install the objective ``sum_{i,j} y[i, j] * z[i, j]``."""
        n = self.num_nodes
        # Plain Python floats, so each coefficient * Var is handled by
        # gurobipy rather than by NumPy scalar arithmetic.
        coeffs = self._as_matrix(y).tolist()
        obj_expr = gp.quicksum(coeffs[i][j] * self.z[i, j]
                               for i in range(n) for j in range(n))
        # No sense argument: Gurobi keeps the model's ModelSense (MAXIMIZE).
        self._model.setObjective(obj_expr)

    def solve(self, y, **kwargs):
        """Solve for objective coefficients ``y``.

        Returns ``(z_tensor, obj_val)``: the flattened ``num_nodes**2``
        float32 solution and Gurobi's objective value. Gurobi raises
        ``GurobiError`` if no solution is available.
        """
        self.setObj(y)
        self._model.update()
        self._model.optimize()

        n = self.num_nodes
        # One batched attribute query instead of n*n per-variable `.x` reads.
        values = self._model.getAttr("X", self.z)
        z_sol = np.array([[values[i, j] for j in range(n)] for i in range(n)])
        obj_val = self._model.objVal
        z_tensor = torch.tensor(z_sol.flatten(), dtype=torch.float32)
        return z_tensor, obj_val
       
