from gurobipy import GRB, Model, quicksum
import numpy as np

from src.solvers.solver import Solver


class NMBoost(Solver):
    """Solver that fits ensemble weights with a per-sample margin LP in Gurobi."""

    def solve(self, args, data_train, env):
        """Build and solve the margin LP, returning primal and dual information.

        The model has one non-negative weight per classifier (summing to 1),
        a free margin variable ``rho_i`` per sample and a non-positive
        variable ``rhoneg_i`` per sample, linked by::

            sum_j y_i * pred[j][i] * w_j >= rho_i
            rhoneg_i <= rho_i - 1 / forest_size
            sum_j w_j == 1

        and maximizes ``sum_i (rhoneg_i + hyperparam * rho_i)``.

        Args:
            args: Configuration object. Reads ``gurobi_time_limit``,
                ``gurobi_num_threads``, ``seed``, ``hyperparam`` and ``E``.
            data_train: Mapping with keys ``"pred"`` and ``"y"``. ``pred`` is
                indexed as ``pred[j][i]`` (classifier ``j``, sample ``i``);
                ``y`` holds one label per sample.
                NOTE(review): labels/predictions are presumably in {-1, +1};
                confirm against the caller.
            env: Gurobi environment passed to ``Model``.

        Returns:
            Tuple ``(alpha, beta, lp_weights, objVal, solveTime)`` where
            ``alpha`` is a per-sample array combining the duals of the two
            per-sample constraints, ``beta`` is the dual of the weight-sum
            constraint, ``lp_weights`` is the array of optimal classifier
            weights, ``objVal`` is the optimal objective value and
            ``solveTime`` is the Gurobi runtime.

        Raises:
            RuntimeError: If Gurobi does not terminate with ``GRB.OPTIMAL``
                (for example because the time limit was hit). Previously this
                case surfaced as an ``UnboundLocalError`` at the return.
        """
        pred = data_train["pred"]
        y_train = np.array(data_train["y"])

        with Model(env=env) as model:
            model.Params.OutputFlag = 0  # Turn off Gurobi output
            model.Params.TimeLimit = args.gurobi_time_limit
            model.Params.Threads = args.gurobi_num_threads
            model.Params.Seed = args.seed

            forest_size = len(pred)  # Number of classifiers/trees
            data_size = len(y_train)  # Number of training samples

            # Falls back to args.E whenever args.hyperparam is falsy
            # (None, but also 0).
            hyperparam = args.hyperparam or args.E

            # Classifier weights w_j >= 0.
            weights = model.addVars(
                forest_size,
                lb=0.0,
                ub=GRB.INFINITY,
                vtype=GRB.CONTINUOUS,
                name="weights",
            )

            # Free per-sample margin rho_i.
            rhoi = model.addVars(
                data_size,
                lb=-GRB.INFINITY,
                ub=GRB.INFINITY,
                vtype=GRB.CONTINUOUS,
                name="rho_i",
            )
            # Non-positive per-sample variable rhoneg_i. It gets its own name
            # so it cannot be confused with rho_i in model dumps.
            rhonegi = model.addVars(
                data_size,
                lb=-GRB.INFINITY,
                ub=0,
                vtype=GRB.CONTINUOUS,
                name="rho_neg_i",
            )

            # Per-sample constraints:
            #   accuracy:    sum_j y_i * h_j(x_i) * w_j >= rho_i
            #   neg. margin: rhoneg_i <= rho_i - 1 / forest_size
            acc_constraints = []
            neg_margin_constraints = []
            for i in range(data_size):
                # quicksum avoids rebuilding the expression on every addition.
                expr = quicksum(
                    y_train[i] * pred[j][i] * weights[j]
                    for j in range(forest_size)
                )
                acc_constraints.append(
                    model.addConstr(expr >= rhoi[i], name=f"ct_acc_{i}")
                )
                neg_margin_constraints.append(
                    model.addConstr(
                        rhonegi[i] <= rhoi[i] - (1 / forest_size),
                        name=f"ct_neg_margin_{i}",
                    )
                )

            # The classifier weights must sum to 1.
            wsum_constraint = model.addConstr(
                quicksum(weights[j] for j in range(forest_size)) == 1.0,
                name="ct_wSum",
            )

            # Objective: maximize sum_i (rhoneg_i + hyperparam * rho_i).
            model.setObjective(
                quicksum(
                    rhonegi[i] + hyperparam * rhoi[i] for i in range(data_size)
                ),
                GRB.MAXIMIZE,
            )

            model.optimize()

            # Without an optimal solution there are no primal/dual values to
            # return, so fail loudly instead of hitting unbound locals below.
            if model.status != GRB.OPTIMAL:
                print("No optimal solution found.")
                raise RuntimeError(
                    "NMBoost: no optimal solution found "
                    f"(Gurobi status {model.status})."
                )

            lp_weights = np.array([weights[j].x for j in range(forest_size)])
            # Per-sample value: dual of the negative-margin constraint plus
            # the absolute dual of the accuracy constraint.
            alpha = np.array(
                [
                    neg_margin_constraints[i].Pi + abs(acc_constraints[i].Pi)
                    for i in range(data_size)
                ]
            )
            if np.all(alpha == 0):
                print("Warning: All sample weights (from the dual) are zero.")
            beta = wsum_constraint.Pi
            for j in range(forest_size):
                print(f"Weight of classifier {j} = {weights[j].x}")
            obj_val = model.objVal
            print(f"Obj. value {obj_val}")
            solve_time = model.Runtime
            print(model.Runtime)

        return alpha, beta, lp_weights, obj_val, solve_time
