import numpy as np
import gurobipy as gp

from src.solvers.solver import Solver
from src.utils.strings import *


class KnapsackWeightsSolver(Solver):
    """0/1 knapsack solver in which ``y`` plays the role of the item weights.

    The first-stage problem picks a binary selection that maximizes
    ``values @ selection`` subject to ``y @ selection <= capacity``.
    ``compute_metrics`` then evaluates a selection against a (possibly
    different) weight vector and, when the capacity is violated, solves a
    second-stage recourse problem that may remove and/or add items.
    """

    def __init__(self):
        """Configure the base solver for a maximization problem.

        The meaning of ``minimization`` and ``positive_y`` is defined by the
        ``Solver`` base class, which is not part of this file.
        """
        super().__init__(minimization=False, positive_y=True)

    @staticmethod
    def _require_optimal(model: gp.Model) -> None:
        """Raise ``AssertionError`` unless ``model`` was solved to optimality.

        This is an explicit check rather than an ``assert`` statement so that
        it is still enforced when Python runs with ``-O`` (which strips
        ``assert``). The exception type and message match the original
        ``assert`` so existing callers are unaffected.
        """
        if model.status != gp.GRB.OPTIMAL:
            raise AssertionError("Problem was not solved to optimality")

    def _solve_method(self, x: np.ndarray, y: np.ndarray, params: dict[str, np.ndarray]) -> tuple[np.ndarray, float]:
        """Solve the first-stage knapsack with ``y`` as the item weights.

        Args:
            x: Not used by this solver.
            y: Item weights; one binary decision variable is created per entry.
            params: Must contain ``VALUES`` (objective coefficients) and
                ``CAPACITY`` (right-hand side of the weight constraint).

        Returns:
            A tuple ``(solution, runtime)`` holding the values of the binary
            selection variables and the ``Runtime`` attribute reported by
            Gurobi for the solve.

        Raises:
            AssertionError: If Gurobi does not report an optimal status.
        """
        values = params[VALUES]
        capacity = params[CAPACITY]

        dim = len(y)

        # Create the Gurobi model
        model = gp.Model()

        # Suppress Gurobi output
        model.setParam('OutputFlag', 0)

        objs = model.addMVar(shape=(dim,), vtype=gp.GRB.BINARY, name="objs")

        # Define the model
        model.setObjective(values @ objs, gp.GRB.MAXIMIZE)
        model.addConstr(y @ objs <= capacity, name="eq")

        # Solve the model
        model.optimize()

        # Sanity check
        self._require_optimal(model)

        return objs.X, model.Runtime

    def compute_metrics(self, y: np.ndarray, solution: np.ndarray, params: dict[str, np.ndarray]) -> dict:
        """Evaluate ``solution`` against the weights ``y``.

        If the selection exceeds the capacity under ``y``, the recourse
        problem is solved to obtain the penalty term; otherwise the penalty
        is zero.

        Args:
            y: Item weights used for the feasibility check and the recourse.
            solution: Item selection to evaluate.
            params: Must contain ``CAPACITY``, ``PENALTY`` and ``VALUES``.

        Returns:
            A dict with the keys ``TOTAL_COST``, ``SUBOPTIMALITY_COST``,
            ``PENALTY_COST`` (built-in ``float``) and ``FEASIBLE`` (built-in
            ``bool``). Values are plain Python scalars in both branches.

        Raises:
            AssertionError: If the recourse problem is not solved to
                optimality.
        """
        capacity = float(params[CAPACITY])
        penalty = float(params[PENALTY])
        values = params[VALUES]

        # Cast to a built-in bool so both branches expose the same type.
        feasible = bool(solution @ y <= capacity)

        if feasible:
            # Cast to float to match the types returned by the recourse path.
            sub_cost = float(solution @ values)
            penalty_cost = 0.0
        else:
            # Recourse action
            sub_cost, penalty_cost = self._recourse_action(y, solution, values, capacity, penalty)

        metrics = {
            TOTAL_COST: sub_cost + penalty_cost,
            SUBOPTIMALITY_COST: sub_cost,
            PENALTY_COST: penalty_cost,
            FEASIBLE: feasible,
        }

        return metrics

    def _recourse_action(self, weights: np.ndarray, solution: np.ndarray, values: np.ndarray,
                         capacity: float, penalty: float) -> tuple[float, float]:
        """Solve the second-stage problem that repairs a first-stage selection.

        Binary variables ``u_plus`` (items added) and ``u_minus`` (items
        removed) adjust ``solution`` so that the total weight fits within
        ``capacity``. Added items contribute ``value / penalty`` to the
        objective and removed items subtract ``penalty * value``.

        Args:
            weights: Item weights used in the capacity constraint.
            solution: First-stage item selection.
            values: Item values.
            capacity: Knapsack capacity.
            penalty: Recourse factor; a value of zero raises
                ``ZeroDivisionError`` because of the ``1 / penalty`` term.

        Returns:
            A tuple ``(sub_cost, penalty_cost)`` where ``sub_cost`` is
            ``values @ solution`` and ``penalty_cost`` is the second-stage
            objective contribution evaluated at the optimal recourse.

        Raises:
            AssertionError: If Gurobi does not report an optimal status.
        """
        dim = len(weights)

        # Create the Gurobi model
        model = gp.Model()

        # Suppress Gurobi output
        model.setParam('OutputFlag', 0)

        # Second-stage decisions
        # Selected items
        u_plus = model.addMVar(shape=(dim,), vtype=gp.GRB.BINARY, name="u_plus")
        # Removed items
        u_minus = model.addMVar(shape=(dim,), vtype=gp.GRB.BINARY, name="u_minus")

        # Second-stage constraints
        model.addConstr(weights @ solution + weights @ u_plus - weights @ u_minus <= capacity, name="eq")

        # We can only remove already selected items
        model.addConstr(solution >= u_minus, name="removal cons")

        # We can only add items that have not been selected during first-stage
        model.addConstr(solution + u_plus <= 1, name="additive cons")

        # Define the objective function; the first-stage term is a constant
        # with respect to the second-stage variables.
        first_stage_cost = values @ solution
        second_stage_cost = 1 / penalty * values @ u_plus - penalty * values @ u_minus
        objective = first_stage_cost + second_stage_cost
        model.setObjective(objective, gp.GRB.MAXIMIZE)

        # Solve the model
        model.optimize()

        # Sanity check
        self._require_optimal(model)

        sub_cost = float(values @ solution)
        penalty_cost = float(1 / penalty * values @ u_plus.X - penalty * values @ u_minus.X)

        return sub_cost, penalty_cost
