import numpy as np
import math
from gurobipy import Model, GRB
from src.solvers.solver import Solver

"""
Warmuth, M.K., Glocer, K.A., Vishwanathan, S.V.N. 
Entropy Regularized LPBoost
Lecture Notes in Computer Science, 5254 (2008)
https://doi.org/10.1007/978-3-540-87987-9_23
"""


class ERLPBoost(Solver):
    """Entropy-regularized LPBoost over a fixed set of classifiers.

    Follows "Entropy Regularized LPBoost" (Warmuth, Glocer, Vishwanathan,
    2008): a capped distribution over the training samples is chosen to
    minimize the maximum edge of the classifiers plus an entropy term
    scaled by ``1 / eta``. The entropy term is handled by a sequence of
    quadratic programs solved with Gurobi.
    """

    def solve(self, args, data_train, env):
        """Solve the regularized problem and return weights and statistics.

        Args:
            args: Settings object. The attributes read here are
                ``qp_tolerance``, ``hyperparam`` (falling back to ``G``
                when falsy), ``gurobi_num_threads`` and ``seed``.
            data_train: Mapping with ``"pred"`` (one sequence of
                predictions per classifier, indexed by sample) and
                ``"y"`` (the labels).
            env: Gurobi environment passed to ``Model``.

        Returns:
            The tuple ``(alpha, beta, lp_weights, objval, total_solve_time)``:
            ``alpha`` is the final sample distribution, ``beta`` is always
            ``0.0``, ``lp_weights`` holds the absolute duals of the margin
            constraints (one per classifier), ``objval`` is the objective
            value of the last solve and ``total_solve_time`` is the summed
            Gurobi runtime of the solves inside the iteration loop. If
            Gurobi does not report an optimal status, a tuple of five
            ``None`` values is returned instead.
        """
        # Local import: the module-level import only brings in Model/GRB.
        from gurobipy import quicksum

        pred = data_train["pred"]
        y_train = np.array(data_train["y"])
        data_size = len(y_train)
        forest_size = len(pred)

        # Linearization point of the entropy term; starts out uniform.
        dist = np.full(data_size, 1 / data_size)

        # Regularization strength and iteration budget from the tolerance.
        ln_n_sample = math.log(data_size)
        half_tolerance = args.qp_tolerance / 2.0
        eta = max(0.5, ln_n_sample / half_tolerance)
        # At least one pass, so the model is always solved with the real
        # objective even when the budget formula truncates to zero.
        max_iter = max(
            1,
            int(
                max(
                    4.0 / half_tolerance,
                    8.0 * ln_n_sample / half_tolerance**2,
                )
            ),
        )

        hyperparam = args.hyperparam or args.G

        gamma_hat = 1.0  # Smallest objective value observed so far
        epsilon = 1e-9  # Keeps log() and the division away from zero

        # signed_pred[j, i] = y_i * h_j(x_i). Computed once: it supplies
        # both the constraint coefficients and the edges.
        signed_pred = np.array(
            [
                [y_train[i] * pred_j[i] for i in range(data_size)]
                for pred_j in pred
            ],
            dtype=float,
        )

        # The context manager disposes of the model on every exit path.
        with Model(env=env) as model:
            model.Params.OutputFlag = 0  # Suppress output
            model.Params.Threads = args.gurobi_num_threads
            model.Params.Seed = args.seed
            model.Params.NumericFocus = 3

            # Variables: gamma and the capped sample distribution
            gamma = model.addVar(
                lb=-GRB.INFINITY, vtype=GRB.CONTINUOUS, name="gamma"
            )
            dist_vars = model.addVars(
                data_size,
                lb=0.0,
                ub=1.0 / hyperparam,
                vtype=GRB.CONTINUOUS,
                name="dist",
            )

            model.addConstr(
                quicksum(dist_vars[i] for i in range(data_size)) == 1,
                name="sum_is_1",
            )

            # Margin constraints do not depend on the iterate, so they are
            # added exactly once. Re-adding them per iteration would
            # duplicate rows and spread the duals over the copies.
            margin_constraints = [
                model.addConstr(
                    quicksum(
                        dist_vars[i] * row[i] for i in range(data_size)
                    )
                    <= gamma,
                    name=f"margin_{j}",
                )
                for j, row in enumerate(signed_pred.tolist())
            ]
            print(f"Added {forest_size} margin constraints.")

            total_solve_time = 0.0
            for _ in range(max_iter):  # sequential QP
                # Quadratic surrogate of sum_i d_i * ln(d_i): the logarithm
                # is replaced by its first-order expansion around the
                # previous iterate `dist`.
                entropy = quicksum(
                    dist_vars[i]
                    * (
                        math.log(dist[i] + epsilon)
                        + (dist_vars[i] - dist[i]) / (dist[i] + epsilon)
                    )
                    for i in range(data_size)
                )
                model.setObjective(gamma + (entropy / eta), GRB.MINIMIZE)

                model.optimize()
                if model.status != GRB.OPTIMAL:
                    print("Optimization failed.")
                    return None, None, None, None, None

                dist_new = np.array(
                    [dist_vars[i].x for i in range(data_size)]
                )
                objval = model.objVal
                total_solve_time += model.Runtime

                # Edge of every classifier under the new distribution
                edge = signed_pred @ dist_new
                gamma_star = float(edge.max()) + (
                    float(entropy.getValue()) / eta
                )

                # NOTE(review): all margin constraints are in the model, so
                # gamma should equal max(edge) at the optimum and
                # gamma_star ~= objval; this test then looks like it passes
                # on the first iteration. Confirm the intended stopping
                # rule.
                gamma_hat = min(gamma_hat, objval)
                if gamma_hat - gamma_star <= half_tolerance:
                    break

                dist = dist_new

            # Final re-optimization before extracting duals
            model.optimize()
            if model.status != GRB.OPTIMAL:
                print("Optimization failed.")
                return None, None, None, None, None

            # Sample weights for the new tree
            alpha = np.array([dist_vars[i].x for i in range(data_size)])

            # Classifier weights: magnitudes of the margin-constraint duals
            lp_weights = np.abs(
                np.array([constr.Pi for constr in margin_constraints])
            )

            # The dual of the sum constraint is not extracted; beta is
            # reported as a constant.
            beta = 0.0

            for j, weight in enumerate(lp_weights):
                print(f"Weight of classifier {j} = {weight}")
            print(f"Obj. value {model.objVal}")

        return alpha, beta, lp_weights, objval, total_solve_time
