import numpy as np
from gurobipy import GRB, Model, quicksum

from src.solvers.solver import Solver

"""'
Shen, Chunhua and Hanxi Li. 
Boosting Through Optimization of Margin Distributions. 
IEEE Transactions on Neural Networks 21 659-666 (2009).
https://doi.org/10.1109/TNN.2010.2040484
"""


class MDBoost(Solver):
    """Solve the MDBoost quadratic program for ensemble weights.

    Builds and solves, with Gurobi, the convex QP

        maximize    1^T rho - 1/2 * rho^T A rho
        subject to  rho_i = y_i * sum_j w_j * h_j(x_i)   for every sample i
                    sum_j w_j = D,   w_j >= 0

    where ``A`` has ones on the diagonal and ``-1 / (m - 1)`` elsewhere
    (``m`` = number of training samples). When
    ``args.use_identity_approx`` is set, ``A`` is replaced by the identity.
    """

    @staticmethod
    def _margin_penalty(rho, data_size, use_identity_approx):
        """Return the quadratic penalty ``1/2 * rho^T A rho``.

        Args:
            rho: Gurobi tupledict of margin variables indexed ``0..m-1``.
            data_size: Number of margin variables ``m``.
            use_identity_approx: If true, approximate ``A`` by the identity.

        Returns:
            A Gurobi quadratic expression (or ``0.0`` when ``m == 0``).
        """
        if use_identity_approx:
            # Use identity matrix approximation for A
            print("Using identity matrix approximation.")
            return 0.5 * quicksum(rho[i] * rho[i] for i in range(data_size))

        # Full A (variance-style penalty with off-diagonal terms). With a
        # single sample there are no off-diagonal entries, so avoid the
        # 1 / (m - 1) division by zero; A is then simply [[1]].
        off_diag = -1 / (data_size - 1) if data_size > 1 else 0.0
        A = np.full((data_size, data_size), off_diag)
        np.fill_diagonal(A, 1)

        # Object array of Gurobi Vars so np.dot builds the quadratic form.
        # NOTE: this materialises O(m^2) expression terms.
        rho_values = np.array([rho[i] for i in range(data_size)])
        return 0.5 * np.dot(rho_values, np.dot(A, rho_values))

    def solve(self, args, data_train, env):
        """Solve the MDBoost QP and return primal weights and dual values.

        Args:
            args: Configuration namespace. Reads ``gurobi_time_limit``,
                ``gurobi_num_threads``, ``seed``, ``hyperparam`` (falls back
                to ``C`` when ``hyperparam`` is falsy) and
                ``use_identity_approx``.
            data_train: Mapping with ``"pred"`` (indexed ``pred[j][i]``: the
                prediction of classifier ``j`` on sample ``i``) and ``"y"``
                (training labels). Labels and predictions are presumably in
                {-1, +1}; TODO confirm with the caller.
            env: Gurobi environment used to create the model.

        Returns:
            Tuple ``(alpha, beta, lp_weights, objVal, solveTime)``:
            duals of the margin constraints clipped at zero (sample
            weights), dual of the weight-sum constraint, classifier weights,
            optimal objective value, and Gurobi runtime in seconds.

        Raises:
            RuntimeError: If Gurobi does not report an optimal solution
                (for example, when the time limit is reached).
        """
        pred = data_train["pred"]
        y_train = np.array(data_train["y"])

        forest_size = len(pred)
        data_size = len(y_train)

        # `or` means a falsy hyperparam (None or 0) falls back to args.C.
        hyperparam = args.hyperparam or args.C

        # Create the Gurobi model
        with Model(env=env) as model:
            model.Params.OutputFlag = 0  # Turn off Gurobi output
            model.Params.TimeLimit = args.gurobi_time_limit
            model.Params.Threads = args.gurobi_num_threads
            model.Params.Seed = args.seed

            # Classifier weights w_j >= 0
            weights = model.addVars(
                forest_size,
                lb=0.0,
                ub=GRB.INFINITY,
                vtype=GRB.CONTINUOUS,
                name="weights",
            )

            # Free margin variables rho_i, one per training sample
            rho = model.addVars(
                data_size,
                lb=-GRB.INFINITY,
                ub=GRB.INFINITY,
                vtype=GRB.CONTINUOUS,
                name="rho",
            )

            # Margin constraints: rho_i = y_i * sum_j (w_j * h_j(x_i)).
            # quicksum avoids the quadratic cost of sum() on Gurobi exprs.
            margin_constraints = []
            for i in range(data_size):
                expr = quicksum(
                    y_train[i] * pred[j][i] * weights[j]
                    for j in range(forest_size)
                )
                margin_constraints.append(
                    model.addConstr(rho[i] == expr, name=f"ct_margin_{i}")
                )

            # The classifier weights must sum to the hyperparameter D
            wsum_constraint = model.addConstr(
                quicksum(weights[j] for j in range(forest_size)) == hyperparam,
                name="ct_wSum",
            )

            # Objective: maximize 1^T rho - (1/2) * rho^T A rho
            quadratic_term = self._margin_penalty(
                rho, data_size, args.use_identity_approx
            )
            linear_term = quicksum(rho[i] for i in range(data_size))
            model.setObjective(linear_term - quadratic_term, GRB.MAXIMIZE)

            model.optimize()

            status = model.status
            if status != GRB.OPTIMAL:
                # Without an optimal solution there are no duals/weights to
                # return; fail loudly instead of hitting unbound locals.
                print("No optimal solution found.")
                raise RuntimeError(
                    f"MDBoost: Gurobi did not reach optimality "
                    f"(status {status})."
                )

            lp_weights = np.array(
                [weights[j].x for j in range(forest_size)]
            )
            # Dual values of the margin constraints; negatives clipped to 0
            alpha = np.array(
                [max(0, constr.Pi) for constr in margin_constraints]
            )
            if np.all(alpha == 0):
                print(
                    "Warning: All sample weights (from the dual) are zero."
                )
            # Dual value of the sum-of-weights constraint
            beta = wsum_constraint.Pi
            for j in range(forest_size):
                print(f"Weight of classifier {j} = {weights[j].x}")
            objVal = model.objVal
            print(f"Obj. value {objVal}")
            solveTime = model.Runtime
            print(model.Runtime)

        return alpha, beta, lp_weights, objVal, solveTime
