import math

import numpy as np
from gurobipy import GRB, Model, quicksum

from src.solvers.solver import Solver


class QRLPBoost(Solver):
    """Boosting solver that repeatedly solves a quadratically regularized QP.

    Each iteration minimizes ``gamma + (1 / eta) * R(d)`` over a capped
    distribution ``d`` (``0 <= d_i <= 1 / hyperparam``, ``sum(d) == 1``)
    subject to one margin constraint per hypothesis
    (``sum_i d_i * y_i * pred[j][i] <= gamma``).  ``R`` is the second-order
    Taylor expansion of ``sum_i d_i ln d_i`` around the previous distribution
    (constant terms dropped), so it is re-centred on every iteration.
    """

    def solve(self, args, data_train, env):
        """Run the iterative QP until the objective value stabilizes.

        Args:
            args: Options object; reads ``qp_tolerance``, ``hyperparam``
                (falling back to ``nu`` when falsy), ``gurobi_num_threads``
                and ``seed``.
            data_train: Mapping with ``"pred"`` (indexed as ``pred[j][i]``:
                hypothesis ``j``'s prediction on sample ``i``) and ``"y"``
                (sample labels).
            env: Gurobi environment passed to ``Model(env=...)``.

        Returns:
            ``(alpha, beta, weights, objval, total_solve_time)``, all taken
            from the final QP solve:
            ``alpha`` is the optimal sample distribution, ``beta`` the dual
            value of the sum-to-one constraint, ``weights`` the absolute dual
            values of the per-hypothesis margin constraints, ``objval`` the
            final objective, and ``total_solve_time`` the summed Gurobi
            ``Runtime`` over all solves.  Returns five ``None`` values if any
            solve does not end with status ``OPTIMAL``.
        """
        pred = data_train["pred"]
        y_train = np.array(data_train["y"])
        data_size = len(y_train)
        forest_size = len(pred)

        # Uniform initial distribution over the training samples.
        dist = np.full(data_size, 1 / data_size)
        weights = np.zeros(forest_size)  # Weights on hypotheses
        # Previous objective value; inf guarantees at least one full iteration.
        gamma = float("inf")

        # Regularization strength: eta = max(1/2, ln(n) / (tolerance / 2)).
        ln_n_sample = math.log(data_size)
        half_tolerance = args.qp_tolerance / 2.0
        eta = max(0.5, ln_n_sample / half_tolerance)

        # Capping parameter: every distribution entry is bounded by 1/hyperparam.
        hyperparam = args.hyperparam or args.nu

        # The context manager disposes the model on every exit path.
        with Model(env=env) as model:
            model.Params.OutputFlag = 0  # Suppress Gurobi output
            model.Params.Threads = args.gurobi_num_threads
            model.Params.Seed = args.seed

            # Variables: gamma (free) and the capped distribution.
            gamma_var = model.addVar(
                lb=-GRB.INFINITY, vtype=GRB.CONTINUOUS, name="gamma"
            )
            dist_vars = model.addVars(
                data_size,
                lb=0.0,
                ub=1.0 / hyperparam,
                vtype=GRB.CONTINUOUS,
                name="dist",
            )

            # quicksum builds each expression in linear time; Python's sum()
            # copies the partial Gurobi expression on every addition.
            sum_constraint = model.addConstr(
                quicksum(dist_vars[i] for i in range(data_size)) == 1,
                name="sum_is_1",
            )

            # One edge (margin) constraint per hypothesis; their duals become
            # the hypothesis weights.
            margin_constraints = [
                model.addConstr(
                    quicksum(
                        dist_vars[i] * y_train[i] * pred[j][i]
                        for i in range(data_size)
                    )
                    <= gamma_var,
                    name=f"margin_{j}",
                )
                for j in range(forest_size)
            ]

            total_solve_time = 0.0
            while True:
                # Second-order expansion of sum d ln d around `dist`
                # (constants dropped): ln(dist_i) * d_i + d_i^2 / (2 dist_i).
                reg_term = quicksum(
                    (np.log(dist[i]) if dist[i] > 0 else 0) * dist_vars[i]
                    + (dist_vars[i] * dist_vars[i] / (2 * dist[i]))
                    for i in range(data_size)
                )
                model.setObjective(
                    gamma_var + (1 / eta) * reg_term, GRB.MINIMIZE
                )

                model.optimize()

                if model.status != GRB.OPTIMAL:
                    print("Optimization failed.")
                    return None, None, None, None, None

                dist_new = np.array([dist_vars[i].x for i in range(data_size)])
                objval = model.objVal
                total_solve_time += model.Runtime

                # Read the duals from the solve that just finished, so that
                # the returned weights match the returned alpha/beta. They
                # must be read before the convergence check, otherwise an
                # early break would return stale (or all-zero) weights.
                weights = np.array(
                    [abs(constraint.Pi) for constraint in margin_constraints]
                )

                # Stop on convergence, or when an entry hits zero: re-centring
                # the expansion there would divide by zero in `2 * dist[i]`.
                if (
                    np.any(dist_new <= 0)
                    or abs(gamma - objval) < args.qp_tolerance
                ):
                    break

                dist = dist_new
                gamma = objval

            # Sample weights for the next tree: the final optimal distribution.
            alpha = dist_new

            # Dual variable of the sum-to-one constraint.
            beta = sum_constraint.Pi

            for j, weight in enumerate(weights):
                print(f"Weight of classifier {j} = {weight}")
            print(f"Obj. value {objval}")

        return alpha, beta, weights, objval, total_solve_time
