import numpy as np
from gurobipy import GRB, Model

from src.solvers.solver import Solver


"""
Demiriz, A., Bennett, K.P. & Shawe-Taylor, J.
Linear Programming Boosting via Column Generation.
Machine Learning 46, 225–254 (2002).
https://doi.org/10.1023/A:1012470815092

Note that several sources call their approach 'LP-Boosting'; we
implement the soft-margin variant as defined in Demiriz et al. (2002),
Section 3.1, formulation (2).
"""


class LPBoost(Solver):
    """Soft-margin LP-Boosting solved as a single linear program with Gurobi.

    Given the predictions of a fixed set of base classifiers, the LP built in
    ``solve`` finds non-negative classifier weights ``w`` and per-sample
    slacks ``xi`` for

        minimize    sum_j w_j + C * sum_i xi_i
        subject to  y_i * sum_j w_j * h_j(x_i) + xi_i >= 1   for every sample i
                    w_j >= 0,  xi_i >= 0

    where ``h_j(x_i)`` is ``data_train["pred"][j][i]`` and ``C`` is the
    hyperparameter taken from ``args``. The required margin is the constant
    1; there is no explicit margin (rho) variable in this model.
    """

    def solve(self, args, data_train, env):
        """Build and solve the LP, returning its primal and dual solution.

        Args:
            args: Configuration namespace. Reads ``gurobi_time_limit``,
                ``gurobi_num_threads``, ``seed``, ``hyperparam`` and ``D``.
            data_train: Mapping with ``"pred"``, indexed as ``pred[j][i]``
                (prediction of classifier ``j`` on sample ``i``), and ``"y"``
                (one label per sample; presumably in {-1, +1} — confirm
                against the data loader).
            env: Gurobi environment the model is created in.

        Returns:
            Tuple ``(alpha, beta, lp_weights, obj_val, solve_time)``:
            ``alpha`` holds the dual values (``Pi``) of the per-sample margin
            constraints, one per sample; ``beta`` is always ``0.0`` in this
            solver; ``lp_weights`` holds the optimal classifier weights, one
            per classifier; ``obj_val`` is the optimal objective value; and
            ``solve_time`` is Gurobi's ``Model.Runtime``.

        Raises:
            RuntimeError: If Gurobi does not report an optimal solution
                (e.g. the time limit was reached), since the duals and
                objective are then not available.
        """
        pred = data_train["pred"]
        y_train = np.array(data_train["y"])

        forest_size = len(pred)
        data_size = len(y_train)

        # Falls back to ``args.D`` when ``hyperparam`` is unset. Because this
        # uses ``or``, an explicit 0 also falls back to ``D``.
        hyperparam = args.hyperparam or args.D

        with Model(env=env) as model:
            model.Params.OutputFlag = 0  # Turn off Gurobi output
            model.Params.TimeLimit = args.gurobi_time_limit
            model.Params.Threads = args.gurobi_num_threads
            model.Params.Seed = args.seed

            # One non-negative weight per base classifier.
            weights = model.addVars(
                forest_size,
                lb=0.0,
                ub=GRB.INFINITY,
                vtype=GRB.CONTINUOUS,
                name="weights",
            )

            # One non-negative slack variable (xi) per training sample.
            xi = model.addVars(
                data_size, lb=0.0, vtype=GRB.CONTINUOUS, name="xi"
            )

            # Margin constraints: y_i * sum_j w_j * h_j(x_i) + xi_i >= 1.
            # tupledict.prod builds each row in one pass; the built-in sum()
            # would copy the growing expression on every ``+`` (quadratic).
            acc_constraints = []
            for i in range(data_size):
                y_i = y_train[i]
                coeffs = {
                    j: float(y_i * pred[j][i]) for j in range(forest_size)
                }
                acc_constraints.append(
                    model.addConstr(
                        weights.prod(coeffs) + xi[i] >= 1,
                        name=f"ct_acc_{i}",
                    )
                )

            # L1 norm of the (non-negative) weights plus penalized slacks.
            model.setObjective(
                weights.sum() + hyperparam * xi.sum(),
                GRB.MINIMIZE,
            )

            model.optimize()

            if model.Status != GRB.OPTIMAL:
                print("No optimal solution found.")
                # Without an optimal LP solution the duals (Pi) and the
                # objective are not available, so nothing valid can be
                # returned.
                raise RuntimeError(
                    "LPBoost: Gurobi did not find an optimal solution "
                    f"(status code {model.Status})."
                )

            lp_weights = np.array(
                [weights[j].X for j in range(forest_size)]
            )
            # Dual values of the margin constraints (reported as sample
            # weights in the warning below).
            alpha = np.array([constr.Pi for constr in acc_constraints])
            if np.all(alpha == 0):
                print(
                    "Warning: All sample weights (from the dual) are zero."
                )
            beta = 0.0
            for j in range(forest_size):
                print(f"Weight of classifier {j} = {weights[j].X}")
            obj_val = model.ObjVal
            print(f"Obj. value {obj_val}")
            solve_time = model.Runtime
            print(solve_time)

        return alpha, beta, lp_weights, obj_val, solve_time
