import numpy as np
import gurobipy as gp

from src.solvers.solver import Solver
from src.utils.strings import *


class WeightedSetMultiCoverSolver(Solver):
    """Weighted set multi-cover, solved as an integer program with Gurobi.

    Each set ``j`` may be picked an integer number of times ``x_j >= 0`` at a cost of
    ``sets_costs[j]`` per pick; one pick of set ``j`` provides ``cover_matrix[i, j]``
    units of coverage to item ``i``. Demand ``y[i]`` that the chosen sets do not cover
    is absorbed by an integer slack variable charged ``items_costs[i]`` per unit.

    Expected ``params`` entries: ``COVER_MATRIX`` with shape ``(n_items, n_sets)``,
    ``SETS_COSTS`` of length ``n_sets`` and ``ITEMS_COSTS`` of length ``n_items``.
    """

    def __init__(self):
        super().__init__(minimization=True, positive_y=True)

    def _solve_method(self, x: np.ndarray, y: np.ndarray, params: dict[str, np.ndarray]) -> tuple[np.ndarray, float]:
        """Solve the multi-cover problem for the item demands ``y``.

        :param x: Not used by this solver (part of the ``Solver`` interface).
        :param y: Demand of each item, indexed like the rows of ``COVER_MATRIX``.
        :param params: Problem data; see the class docstring.
        :return: ``(solution, runtime)``: ``solution`` is the number of times each set is
            picked (rounded to the nearest integer, as ``float32``) and ``runtime`` is the
            Gurobi model's ``Runtime`` attribute after the solve.
        :raises AssertionError: If the shapes of the inputs disagree or Gurobi does not
            report an optimal status.
        """
        cover_matrix = np.asarray(params[COVER_MATRIX], dtype=float)
        sets_costs = np.asarray(params[SETS_COSTS], dtype=float)
        items_costs = np.asarray(params[ITEMS_COSTS], dtype=float)

        n_items = len(items_costs)
        n_sets = len(sets_costs)

        # Explicit raise (not `assert`) so the check survives `python -O`.
        if cover_matrix.shape != (n_items, n_sets):
            raise AssertionError(
                f"cover matrix has shape {cover_matrix.shape}, expected {(n_items, n_sets)}"
            )

        # Context managers dispose of the Gurobi environment and model deterministically,
        # including when model building or the optimality check raises.
        with gp.Env(empty=True) as env:
            env.setParam("OutputFlag", 0)
            env.start()
            with gp.Model(env=env) as model:
                # Decision variables: how many times each set is picked.
                set_vars = [model.addVar(vtype=gp.GRB.INTEGER, lb=0, name=f'x_{j}') for j in range(n_sets)]
                slack_vars = []

                for i in range(n_items):
                    indicator_var = model.addVar(vtype=gp.GRB.BINARY, name=f'z_{i}')
                    slack_var = model.addVar(vtype=gp.GRB.INTEGER, name=f's_{i}')
                    slack_vars.append(slack_var)

                    # Coverage received by item i: sum_j cover_matrix[i, j] * x_j.
                    coverage = gp.LinExpr(cover_matrix[i].tolist(), set_vars)
                    demand = float(y[i])

                    # z_i = 1: the demand may be met partly by the slack s_i (penalised below).
                    model.addGenConstrIndicator(binvar=indicator_var,
                                                binval=True,
                                                lhs=slack_var + coverage,
                                                sense=gp.GRB.GREATER_EQUAL,
                                                rhs=demand,
                                                name=f'Indicator_constraint_{i}')

                    # z_i = 0: the sets alone must cover the demand.
                    model.addConstr(coverage >= demand * (1 - indicator_var))

                # Objective: cost of the picked sets plus the penalty on the slack variables.
                objective = (gp.LinExpr(sets_costs.tolist(), set_vars)
                             + gp.LinExpr(items_costs.tolist(), slack_vars))
                model.setObjective(objective, gp.GRB.MINIMIZE)

                model.optimize()
                if model.Status != gp.GRB.Status.OPTIMAL:
                    raise AssertionError("Solution is not optimal")

                # Integer variables are only integral up to Gurobi's integrality tolerance;
                # rounding keeps e.g. 2.9999999 picks from leaving a tiny "uncovered" demand.
                solution = np.round(model.getAttr("X", set_vars)).astype(np.float32)
                # Read Runtime before the context managers dispose of the model.
                return solution, model.Runtime

    def compute_metrics(self, y: np.ndarray, solution: np.ndarray, params: dict[str, np.ndarray]) -> dict:
        """Evaluate a set-pick vector against the demands ``y``.

        :param y: Demand of each item.
        :param solution: Number of times each set is picked.
        :param params: Problem data; see the class docstring.
        :return: Dict with ``TOTAL_COST`` (set cost plus penalty), ``SUBOPTIMALITY_COST``
            (``sets_costs @ solution``), ``PENALTY_COST`` (``items_costs`` times the
            per-item uncovered demand, clipped at 0) and ``FEASIBLE`` (True when no item
            has uncovered demand).
        :raises AssertionError: If the shapes of the inputs disagree.
        """
        metrics = dict()

        cover_matrix = params[COVER_MATRIX]
        sets_costs = params[SETS_COSTS]
        items_costs = params[ITEMS_COSTS]

        n_items = len(items_costs)
        n_sets = len(sets_costs)

        # Explicit raise (not `assert`) so the check survives `python -O`.
        if cover_matrix.shape != (n_items, n_sets):
            raise AssertionError(
                f"cover matrix has shape {cover_matrix.shape}, expected {(n_items, n_sets)}"
            )

        sub_cost = sets_costs @ solution
        # Over-coverage is not rewarded, so negative shortfalls are clipped to zero.
        uncovered_demands = np.clip(y - cover_matrix @ solution, a_min=0, a_max=None)
        feasible = not (uncovered_demands > 0).any()
        penalty_cost = items_costs @ uncovered_demands

        metrics[TOTAL_COST] = sub_cost + penalty_cost
        metrics[SUBOPTIMALITY_COST] = sub_cost
        metrics[PENALTY_COST] = penalty_cost
        metrics[FEASIBLE] = feasible

        return metrics
