# from blossom import BlossomClassifier  # ensure to import blossom first

from time import time

import numpy as np
import numpy.typing as npt
from sklearn.tree import DecisionTreeClassifier
from gurobipy import Env

from src.solvers.solver import Solver
from src.solvers.utils import get_solver
from src.utils.data_handler import DataHandler
from src.utils.results_tracker import ResultTracker


class ColumnGenerator:
    """Column-generation training loop over weak tree learners.

    Each iteration of :meth:`train` fits a new tree on the current sample
    weights, records its predictions for every data split through the data
    handler, and then calls the configured solver to obtain new sample
    weights, ``beta``, learner weights and the objective value. Metrics are
    pushed to a ``ResultTracker`` after every solve.
    """

    # Solver backend returned by ``get_solver(args.solver)``. The annotation
    # is quoted so that it is not evaluated when the class body executes.
    solver: "Solver"

    def __init__(self, args):
        """Build the solver, data splits, result tracker and Gurobi env.

        Args:
            args: Run configuration. The attributes read in this class are
                ``solver``, ``itermax``, ``tree_type``, ``max_depth``,
                ``seed``, ``crb``, ``check_dual_const``, ``obj_check``,
                ``obj_eps`` and ``checkpoint``.
        """
        self.args = args
        self.solver = get_solver(self.args.solver)
        self.data_handler = DataHandler(self.args)
        self.data = self.data_handler.get_all_splits()
        self.tracker = ResultTracker(self.args)
        self.env = Env(params={"LogFile": ""})

    def __del__(self):
        """Ensure proper cleanup of Gurobi environment"""
        # __init__ may have raised before ``env`` was assigned; in that case
        # there is nothing to dispose and we must not raise from __del__.
        env = getattr(self, "env", None)
        if env is not None:
            env.dispose()

    def train(self):
        """Run column generation until a stopping condition is reached.

        The loop ends when any of the following holds:

        * ``args.itermax`` iterations have been executed;
        * ``args.check_dual_const`` is set and the dual sum of the newest
          learner is not positive;
        * the solver returns ``None`` for the sample weights or ``beta``;
        * on every ``args.obj_check``-th iteration, the objective improved by
          at most ``args.obj_eps`` compared with the previous iteration.

        Results are finalized on checkpoint iterations, on the last allowed
        iteration, and on the iteration where the objective criterion stops
        the loop.
        """
        prev_obj = float("inf")
        iters = 0
        beta = 1e5  # placeholder; only printed until the solver returns one

        # Start from uniform weights over the training samples.
        n_train = len(self.data["train"]["y"])
        sample_weights = [1.0 / n_train] * n_train if n_train else []

        while iters < self.args.itermax:
            new_learner = self._train_learner(
                self.data["train"]["x"],
                self.data["train"]["y"],
                sample_weights,
                self.args.tree_type,
            )
            self._create_pred_matrices(new_learner, self.data)
            # Weighted agreement between labels and the newest learner's
            # predictions (presumably the last entry of "pred" is the column
            # just added by the data handler -- confirm in DataHandler).
            dual_sum = np.dot(
                sample_weights * self.data["train"]["y"],
                self.data["train"]["pred"][-1],
            )
            print(f"Dual sum = {dual_sum}, beta = {beta}")

            if dual_sum <= 0.0 and self.args.check_dual_const:
                print("Dual constraint not satisfied, terminating")
                break

            start_time = time()

            # Solve the LP model and extract dual variables and sample weights
            sample_weights, beta, optim_weights, objval, solvetime = (
                self.solver.solve(self.args, self.data["train"], self.env)
            )

            if sample_weights is None or beta is None:
                print(
                    "Model did not solve optimally. Exiting column generation."
                )
                break
            curr_com_time = time() - start_time

            iteration = iters + 1
            self.tracker.update_results(
                data=self.data,
                weights=optim_weights,
                computational_time=curr_com_time,
                objval=objval,
                solvetime=solvetime,
                iteration=iteration,
            )
            accuracies = self.tracker.results["accuracies"]
            print(
                f"Iteration {iteration}, Train Accuracy: {accuracies['train'][-1]:.4f}, Test Accuracy: {accuracies['test'][-1]:.4f}"
            )

            z_diff = prev_obj - objval
            converged = self._has_converged(iteration, z_diff)
            if converged:
                print(
                    f"Stopping criterion met at iteration {iteration}: z_diff = {z_diff:.4f}, threshold = {self.args.obj_eps}"
                )
            prev_obj = objval

            # Finalize on checkpoints and on whichever iteration turns out to
            # be the last one (iteration budget exhausted or early stop).
            if (
                converged
                or iters % self.args.checkpoint == 0
                or iters == self.args.itermax - 1
            ):
                self.tracker.finalize_results(
                    weights=optim_weights, data=self.data
                )
            if converged:
                break
            iters += 1

    def _has_converged(self, iteration: int, z_diff: float) -> bool:
        """Return whether the objective-improvement criterion stops training.

        Args:
            iteration: 1-based index of the iteration that just finished.
            z_diff: Previous objective value minus the current one.

        Returns:
            ``True`` when ``iteration`` is a multiple of ``args.obj_check``
            and ``z_diff`` does not exceed ``args.obj_eps``.
        """
        # The parentheses matter: ``iters + 1 % k`` would bind as
        # ``iters + (1 % k)`` and the check would never trigger.
        on_check_iteration = iteration % self.args.obj_check == 0
        return bool(on_check_iteration and z_diff <= self.args.obj_eps)

    def _train_learner(
        self,
        x: npt.NDArray[np.float64],
        y: npt.NDArray[np.float64],
        sample_weights: npt.NDArray[np.float64],
        tree_type: str = "CART",
    ):
        """Train weak tree learner.

        Args:
            x: Training features.
            y: Training labels; the blossom branch maps them from -1/+1 to
                0/1 before fitting.
            sample_weights: Per-sample weights forwarded to ``fit``.
            tree_type: ``"CART"`` or ``"blossom"``.

        Returns:
            Whatever the learner's ``fit`` returns (the fitted estimator for
            scikit-learn trees).

        Raises:
            ValueError: If ``tree_type`` is not one of the supported values.
        """
        if tree_type == "CART":
            # Train a classification tree using scikit-learn's CART learner
            learner = DecisionTreeClassifier(max_depth=self.args.max_depth)
        elif tree_type == "blossom":
            # NOTE(review): BlossomClassifier is not imported in the visible
            # part of this module (the import at the top is commented out),
            # so this branch looks like it raises NameError -- confirm.
            learner = BlossomClassifier(
                max_depth=self.args.max_depth,
                time=300,
                minsize=False,
                mindepth=False,
                seed=self.args.seed,
                search=True,
                preprocessing=False,
            )
            y = (y + 1) // 2  # blossom expects 0/1 labels instead of -1/+1
            # ``.values`` assumes x is a pandas object here -- TODO confirm.
            x = x.values.tolist()
            y = y.tolist()
        else:
            msg = f"Invalid tree type: {tree_type}"
            raise ValueError(msg)
        return learner.fit(x, y, sample_weight=sample_weights)

    def _create_pred_matrices(self, learner: object, data: dict) -> None:
        """Register the learner's predictions for every split in ``data``.

        For each key of ``data`` the predictions on that split's ``"x"`` are
        computed and handed to ``data_handler.add_pred``.
        """
        for split in data:
            x = data[split]["x"]
            pred_matrix = self._create_pred_matrix(learner, x)
            self.data_handler.add_pred(split, pred_matrix)

    def _create_pred_matrix(
        self,
        learner: object,
        x: npt.NDArray[np.float64],
    ):
        """Return the learner's predictions on ``x`` as a NumPy array.

        With ``args.crb`` set, the values come from :meth:`_get_crb_pred`
        (a probability difference); otherwise from ``learner.predict``.
        """
        if not isinstance(x, np.ndarray):
            x = np.array(x)
        if self.args.crb:
            return np.array(self._get_crb_pred(learner, x))
        return np.array(learner.predict(x))

    def _get_crb_pred(self, learner, x):
        """Return ``predict_proba`` column 1 minus column 0 for each row.

        Presumably column 0 is the negative class and column 1 the positive
        one, giving a score in [-1, 1] -- verify against the label encoding.
        """
        proba = learner.predict_proba(x)
        return proba[:, 1] - proba[:, 0]
