import numpy as np
import gurobipy as gp

from src.solvers.solver import Solver
from src.utils.strings import *


class KnapsackValuesSolver(Solver):
    """0-1 knapsack solver whose item values are given by ``y``.

    Selects the subset of items that maximizes ``y @ selection`` subject to
    ``weights @ selection <= capacity``, solved exactly as a binary program
    with Gurobi. The base ``Solver`` is configured with
    ``minimization=False`` and ``positive_y=True`` (presumably the base
    class uses the latter to keep values non-negative -- confirm in
    ``Solver``).
    """

    def __init__(self):
        super().__init__(minimization=False, positive_y=True)

    def _solve_method(self, x: np.ndarray, y: np.ndarray, params: dict[str, np.ndarray]) -> tuple[np.ndarray, float]:
        """Solve the knapsack instance for item values ``y``.

        Args:
            x: Not used by this solver; kept for the ``Solver`` interface.
            y: Item values, one per item; ``len(y)`` is the number of items.
            params: Must contain ``WEIGHTS`` and ``CAPACITY``. The expression
                ``weights @ objs <= capacity`` must be well-formed --
                presumably a 1-D weight vector with a scalar capacity, or an
                (m, n) weight matrix with m capacities; confirm against callers.

        Returns:
            ``(selection, runtime)``: ``selection`` is a float array of shape
            ``(len(y),)`` holding exactly 0.0/1.0 per item, and ``runtime`` is
            Gurobi's ``Runtime`` attribute for the solve.

        Raises:
            AssertionError: if Gurobi does not report an optimal status.
        """
        weights = params[WEIGHTS]
        capacity = params[CAPACITY]

        dim = len(y)

        # Use the model as a context manager so its native Gurobi resources
        # are released deterministically instead of waiting for the GC.
        with gp.Model() as model:
            # Suppress Gurobi output
            model.setParam('OutputFlag', 0)

            objs = model.addMVar(shape=(dim,), vtype=gp.GRB.BINARY, name="objs")

            # Define the model: maximize total value under the weight limit
            model.setObjective(y @ objs, gp.GRB.MAXIMIZE)
            model.addConstr(weights @ objs <= capacity, name="eq")

            # Solve the model
            model.optimize()

            # Sanity check. An explicit raise (not `assert`) keeps the check
            # active under `python -O`; the exception type is unchanged.
            if model.status != gp.GRB.OPTIMAL:
                raise AssertionError("Problem was not solved to optimality")

            # Binary variables are only integral within Gurobi's IntFeasTol,
            # so values such as 0.99999 may appear; snap them to exact 0/1.
            solution = np.round(objs.X)
            # Read attributes before the model is disposed on leaving the block.
            runtime = model.Runtime

        return solution, runtime

    def compute_metrics(self, y: np.ndarray, solution: np.ndarray, params: dict[str, np.ndarray]) -> dict:
        """Score a knapsack selection under item values ``y``.

        Args:
            y: Item values used to score ``solution``.
            solution: Per-item selection (0/1), e.g. as returned by
                ``_solve_method``.
            params: Not used here; kept for the ``Solver`` interface.

        Returns:
            Dict with ``TOTAL_COST`` and ``SUBOPTIMALITY_COST`` both set to
            the total value ``solution @ y``, ``PENALTY_COST`` set to 0.0 and
            ``FEASIBLE`` set to True. Feasibility is not re-checked here.
            NOTE(review): this assumes ``solution`` was produced with the same
            WEIGHTS/CAPACITY, so it already satisfies the capacity constraint.
            Presumably any regret is derived from SUBOPTIMALITY_COST by the
            caller -- confirm.
        """

        metrics = dict()

        total_cost = solution @ y

        metrics[TOTAL_COST] = total_cost
        metrics[SUBOPTIMALITY_COST] = total_cost
        metrics[PENALTY_COST] = 0.0
        metrics[FEASIBLE] = True

        return metrics
