import numpy as np
from gurobipy import GRB, Model

from src.solvers.solver import Solver

"""'
Bi, Jinbo and Zhang, Tong and Bennett, Kristin P.,
Column-generation boosting methods for mixture of kernels,
SIGKDD International Conference on Knowledge Discovery and Data Mining (2004).
https://doi.org/10.1145/1014052.1014113
"""


class CGBoost(Solver):
    """Soft-margin quadratic program that weights an ensemble of classifiers.

    Implements the CG-Boost style formulation cited by this module
    (Bi, Zhang and Bennett, "Column-generation boosting methods for mixture
    of kernels", SIGKDD 2004).
    """

    def solve(self, args, data_train, env):
        """Solve the CG-Boost master problem on the training predictions.

        The model built here is::

            min   0.5 * sum_j w_j^2 + C * sum_i xi_i
            s.t.  y_i * sum_j w_j * pred[j][i] + xi_i >= 1   for every i
                  w_j >= 0, xi_i >= 0

        Args:
            args: Configuration object. Reads ``gurobi_time_limit``,
                ``gurobi_num_threads``, ``seed``, ``hyperparam`` and ``F``.
                ``C`` is ``args.hyperparam``, falling back to ``args.F`` when
                ``hyperparam`` is falsy (``None`` or ``0``).
            data_train: Mapping with ``"pred"`` (indexed as
                ``pred[classifier][sample]``) and ``"y"`` (one label per
                sample; presumably in {-1, +1} -- TODO confirm with caller).
            env: Gurobi environment handed to ``Model(env=...)``.

        Returns:
            Tuple ``(alpha, beta, lp_weights, obj_val, solve_time)``:
            ``alpha`` is the array of dual values of the per-sample
            constraints clipped below at 0, ``beta`` is always ``0``,
            ``lp_weights`` is the array of optimal classifier weights,
            ``obj_val`` is the optimal objective value and ``solve_time`` is
            Gurobi's reported runtime.

        Raises:
            RuntimeError: If Gurobi does not finish with status
                ``GRB.OPTIMAL`` (for example the time limit was hit). There
                is then no solution/dual information to return.
        """
        pred = data_train["pred"]
        y_train = np.array(data_train["y"])

        # The context manager disposes of the model even if we raise below.
        with Model(env=env) as model:
            model.Params.OutputFlag = 0  # Turn off Gurobi output
            model.Params.TimeLimit = args.gurobi_time_limit
            model.Params.Threads = args.gurobi_num_threads
            model.Params.Seed = args.seed

            forest_size = len(pred)
            data_size = len(y_train)

            # Trade-off constant C on the slack penalty.
            hyperparam = args.hyperparam or args.F

            # Non-negative weight for each classifier in the ensemble.
            weights = model.addVars(
                forest_size,
                lb=0.0,
                ub=GRB.INFINITY,
                vtype=GRB.CONTINUOUS,
                name="weights",
            )

            # Non-negative slack (xi) for each training sample.
            xi = model.addVars(
                data_size, lb=0.0, vtype=GRB.CONTINUOUS, name="xi"
            )

            # Margin constraints with a fixed unit margin:
            #   y_i * sum_j (w_j * h_j(x_i)) + xi_i >= 1
            # The handles are kept so their duals can be read after solving.
            acc_constraints = []
            for i in range(data_size):
                expr = sum(
                    y_train[i] * pred[j][i] * weights[j]
                    for j in range(forest_size)
                )
                acc_constraints.append(
                    model.addConstr(expr + xi[i] >= 1, name=f"ct_acc_{i}")
                )

            # Objective: L2 regularisation of the weights plus the
            # C-weighted sum of slacks.
            model.setObjective(
                0.5 * sum(weights[j] ** 2 for j in range(forest_size))
                + hyperparam * sum(xi[i] for i in range(data_size)),
                GRB.MINIMIZE,
            )

            model.optimize()

            # Without an optimal solution none of the results below exist;
            # fail loudly instead of hitting an UnboundLocalError at return.
            if model.status != GRB.OPTIMAL:
                raise RuntimeError(
                    "No optimal solution found "
                    f"(Gurobi status code {model.status})."
                )

            weight_values = [weights[j].x for j in range(forest_size)]
            lp_weights = np.array(weight_values)

            # Duals of the margin constraints act as per-sample weights;
            # clip at zero to discard small negative numerical noise.
            alpha = np.array(
                [max(0, acc_constraints[i].Pi) for i in range(data_size)]
            )
            if np.all(alpha == 0):
                print("Warning: All sample weights (from the dual) are zero.")

            beta = 0  # no sum of weights constraint in CG-boost

            for j, value in enumerate(weight_values):
                print(f"Weight of classifier {j} = {value}")

            obj_val = model.objVal
            print(f"Obj. value {obj_val}")

            solve_time = model.Runtime
            print(solve_time)

        return alpha, beta, lp_weights, obj_val, solve_time
