from time import time
import numpy as np
import gurobipy as gp

from src.utils.strings import *
from src.solvers.solver import Solver


class ProductionPlanningSolver(Solver):
    """Production-planning problem with quadratic over/underproduction costs.

    Chooses non-negative integer production quantities ``p`` whose total does
    not exceed a capacity, minimising

        sum_i over_i * max(0, p_i - y_i)**2 + under_i * max(0, y_i - p_i)**2

    where ``y`` holds the per-product demands and ``over`` / ``under`` are the
    per-product cost coefficients taken from ``params``. The model is solved
    with Gurobi as a mixed-integer quadratic program.
    """

    def __init__(self):
        # NOTE(review): the meaning of these flags is defined by the base
        # ``Solver`` class (not shown here); presumably they mark the problem
        # as a minimisation and require non-negative demands ``y``.
        super().__init__(minimization=True, positive_y=True)

    def _solve_method(self, x: np.ndarray, y: np.ndarray, params: dict[str, np.ndarray]) -> tuple[np.ndarray, float]:
        """Build and solve the production-planning model for demands ``y``.

        Args:
            x: Not used by this solver (part of the base-class signature).
            y: Demand of each product, one entry per product. Values need
                not be integral.
            params: Must provide ``UNDERPRODUCTION_COSTS`` and
                ``OVERPRODUCTION_COSTS`` (one coefficient per product, same
                length as ``y``) and ``CAPACITY`` (upper bound on the total
                production over all products).

        Returns:
            ``(solution, runtime)``: ``solution`` is a float array with the
            production quantity Gurobi found for each product (length
            ``len(y)``); ``runtime`` is the wall-clock time in seconds spent
            building and solving the model.
        """
        start_time = time()

        underproduction_costs = params[UNDERPRODUCTION_COSTS]
        overproduction_costs = params[OVERPRODUCTION_COSTS]
        capacity = params[CAPACITY]

        assert len(underproduction_costs) == len(overproduction_costs) == len(y)
        n_products = len(underproduction_costs)

        model = gp.Model()
        try:
            model.setParam('OutputFlag', 0)

            # Integer production quantity of each product.
            productions = model.addVars(n_products, vtype=gp.GRB.INTEGER, lb=0, name='Production')

            # Total production may not exceed the capacity.
            model.addConstr(gp.quicksum(productions[i] for i in range(n_products)) <= capacity, name="CapacityConstraint")

            # Over-/underproduction slacks. They must be CONTINUOUS: demands may
            # be fractional, and integer slacks would be rounded up to the
            # ceiling of the true deviation, making the optimised objective
            # disagree with ``_get_objective_value``. Since their (presumably
            # non-negative) costs are minimised, at the optimum
            # diff_pos[i] == max(0, p_i - y_i) and diff_neg[i] == max(0, y_i - p_i).
            diff_pos = model.addVars(n_products, vtype=gp.GRB.CONTINUOUS, lb=0, name='DifferencePos')
            diff_neg = model.addVars(n_products, vtype=gp.GRB.CONTINUOUS, lb=0, name='DifferenceNeg')
            model.addConstrs((diff_pos[i] >= productions[i] - y[i] for i in range(n_products)), name='DifferenceConstraintPos')
            model.addConstrs((diff_neg[i] >= y[i] - productions[i] for i in range(n_products)), name='DifferenceConstraintNeg')

            # Quadratic cost of deviating from demand in either direction.
            # The Gurobi expression is kept on the left of each product so it,
            # not a NumPy scalar, drives the multiplication.
            model.setObjective(
                gp.quicksum(diff_pos[i] * diff_pos[i] * overproduction_costs[i] for i in range(n_products)) +
                gp.quicksum(diff_neg[i] * diff_neg[i] * underproduction_costs[i] for i in range(n_products)),
                gp.GRB.MINIMIZE
            )

            model.optimize()

            # Read the production variables directly rather than relying on
            # the creation order of ``model.getVars()``.
            solution = np.array([productions[i].X for i in range(n_products)])
        finally:
            # Release the model's native Gurobi resources promptly instead of
            # waiting for garbage collection.
            model.dispose()

        runtime = time() - start_time

        return solution, runtime

    def compute_metrics(self, y: np.ndarray, solution: np.ndarray, params: dict[str, np.ndarray]) -> dict:
        """Evaluate ``solution`` against demands ``y``.

        The objective value is reported as both the total and the
        suboptimality cost. The penalty cost is always 0.0 and the solution
        is always flagged feasible; the capacity constraint is not re-checked
        here.

        Returns:
            A dict keyed by ``TOTAL_COST``, ``SUBOPTIMALITY_COST``,
            ``PENALTY_COST`` and ``FEASIBLE``.
        """
        metrics = dict()

        total_cost = self._get_objective_value(y, solution, params)

        metrics[TOTAL_COST] = total_cost
        metrics[SUBOPTIMALITY_COST] = total_cost
        metrics[PENALTY_COST] = 0.0
        metrics[FEASIBLE] = True

        return metrics

    def _get_objective_value(self, demands: np.ndarray, solution: np.ndarray, params: dict[str, np.ndarray]) -> float:
        """Return the quadratic over/underproduction cost of ``solution``.

        Computes ``sum_i over_i * max(0, s_i - d_i)**2 + under_i * max(0, d_i - s_i)**2``
        with ``over`` / ``under`` taken from ``params``.

        Args:
            demands: Demand of each product.
            solution: Production quantity of each product (same length).
            params: Must provide ``UNDERPRODUCTION_COSTS`` and
                ``OVERPRODUCTION_COSTS`` with one coefficient per product.

        Returns:
            The total cost as a Python float (0.0 for zero products).
        """
        assert len(demands) == len(solution)

        underproduction_costs = params[UNDERPRODUCTION_COSTS]
        overproduction_costs = params[OVERPRODUCTION_COSTS]

        assert len(underproduction_costs) == len(overproduction_costs) == len(demands)

        # Flatten everything to 1-D so that e.g. an (n, 1) demand column cannot
        # broadcast against an (n,) solution into an (n, n) matrix.
        demands_arr = np.asarray(demands, dtype=float).ravel()
        solution_arr = np.asarray(solution, dtype=float).ravel()
        over_costs = np.asarray(overproduction_costs, dtype=float).ravel()
        under_costs = np.asarray(underproduction_costs, dtype=float).ravel()

        overproduction = np.maximum(solution_arr - demands_arr, 0.0)
        underproduction = np.maximum(demands_arr - solution_arr, 0.0)

        return float(np.sum(over_costs * overproduction ** 2) + np.sum(under_costs * underproduction ** 2))
