# from blossom import BlossomClassifier  # ensure to import blossom first
import time
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import AdaBoostClassifier
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier
from src.utils.data_handler import DataHandler
from src.utils.results_tracker import ResultTracker
from sklearn.metrics import accuracy_score


class BenchmarkTrainer:
    """Fit an off-the-shelf boosting benchmark and record its results.

    The model is selected by ``args.benchmark`` (``"adaboost"``,
    ``"xgboost"`` or ``"lightgbm"``). Data splits are obtained from
    ``DataHandler`` and results are reported through ``ResultTracker``.
    """

    def __init__(self, args):
        """Load all data splits and create the result tracker.

        Args:
            args: Run configuration. This class reads ``benchmark``,
                ``tree_type``, ``max_depth``, ``seed``, ``itermax`` and
                ``hyperparam`` from it; the same object is handed to
                ``DataHandler`` and ``ResultTracker``.
        """
        self.args = args
        self.data_handler = DataHandler(self.args)
        # Indexed below as self.data[split]["x"] / self.data[split]["y"]
        # with split in {"train", "test"}.
        self.data = self.data_handler.get_all_splits()
        self.tracker = ResultTracker(self.args)

    def train(self) -> None:
        """Train the selected benchmark model and store results.

        The reported training time covers model construction and fitting.
        Train/test accuracy is printed, and the tracker is updated once and
        then finalized.

        Raises:
            ValueError: If ``args.benchmark`` is not a supported method.
        """
        # perf_counter is monotonic, so the elapsed time cannot be skewed by
        # system clock adjustments.
        start_time = time.perf_counter()
        model = self._get_model()
        x_train, y_train = self._prepare_split("train")
        model.fit(x_train, y_train)
        train_time = time.perf_counter() - start_time

        # Compute accuracy
        train_acc = self._compute_accuracy("train", model)
        test_acc = self._compute_accuracy("test", model)

        # Collect per-estimator weights (only AdaBoost exposes them)
        weights = self._get_weights(model)

        # Store results
        self.tracker.update_results(
            data=self.data,
            weights=weights,
            computational_time=train_time,
            objval=0.0,  # No LP objective, so set to 0
            solvetime=0.0,  # No solving, so set to 0
            iteration=1,
            model=model,
        )

        print(
            f"Benchmark {self.args.benchmark}: Train Acc: {train_acc:.4f}, "
            f"Test Acc: {test_acc:.4f}, Train Time: {train_time:.2f}s"
        )

        # Save final results
        self.tracker.finalize_results(
            weights=weights, data=self.data, model=model
        )

    def _prepare_split(self, split: str):
        """Return ``(x, y)`` for *split* in the form the model expects.

        Single source of truth for input conversion, so that training and
        evaluation always feed the model the same representation.

        * ``xgboost``: features unchanged, labels mapped from -1/+1 to 0/1.
        * ``tree_type == "blossom"``: features and labels converted to plain
          Python lists, labels mapped from -1/+1 to 0/1.
        * otherwise: features and labels returned unchanged.

        Args:
            split: Key into ``self.data`` (``"train"`` or ``"test"``).

        Returns:
            Tuple ``(x, y)`` ready for ``model.fit`` / ``model.predict``.
        """
        x = self.data[split]["x"]
        y = self.data[split]["y"]
        # The xgboost check comes first so that ``tree_type`` is never
        # consulted for XGBoost runs.
        if self.args.benchmark == "xgboost":
            return x, np.where(y == -1, 0, 1)
        if self.args.tree_type == "blossom":
            # blossom expects 0/1 labels instead of -1/+1, passed as lists;
            # ``.values`` assumes x is pandas-like — TODO confirm.
            return x.values.tolist(), ((y + 1) // 2).tolist()
        return x, y

    def _get_model(self):
        """Retrieve the benchmark model based on the specified method.

        Returns:
            An unfitted ``AdaBoostClassifier``, ``XGBClassifier`` or
            ``LGBMClassifier`` configured from ``self.args``.

        Raises:
            ValueError: If ``args.benchmark`` is not a supported method.
        """
        benchmark = self.args.benchmark
        # A falsy hyperparam (None or 0) falls back to a learning rate of 1.0.
        learning_rate = self.args.hyperparam or 1.0

        if benchmark == "adaboost":
            if self.args.tree_type == "blossom":
                # NOTE(review): BlossomClassifier must be importable at module
                # level; the import at the top of this file is commented out,
                # so this branch raises NameError until it is enabled.
                base_learner = BlossomClassifier(
                    max_depth=self.args.max_depth,
                    time=300,
                    minsize=False,
                    mindepth=False,
                    seed=self.args.seed,
                    search=True,
                    preprocessing=False,
                )
            else:
                base_learner = DecisionTreeClassifier(
                    max_depth=self.args.max_depth, random_state=self.args.seed
                )
            return AdaBoostClassifier(
                estimator=base_learner,
                n_estimators=self.args.itermax,
                learning_rate=learning_rate,
                random_state=self.args.seed,
            )
        if benchmark == "xgboost":
            return XGBClassifier(
                n_estimators=self.args.itermax,
                max_depth=self.args.max_depth,
                random_state=self.args.seed,
                learning_rate=learning_rate,
                verbosity=0,
            )
        if benchmark == "lightgbm":
            return LGBMClassifier(
                n_estimators=self.args.itermax,
                max_depth=self.args.max_depth,
                random_state=self.args.seed,
                learning_rate=learning_rate,
                verbose=-1,
            )
        raise ValueError(f"Invalid benchmark method: {benchmark}")

    def _get_weights(self, model):
        """Get tree weights.

        Args:
            model: The fitted benchmark model.

        Returns:
            ``model.estimator_weights_`` for AdaBoost; a single-zero array
            placeholder for XGBoost and LightGBM.

        Raises:
            ValueError: If ``args.benchmark`` is not a supported method.
        """
        benchmark = self.args.benchmark
        if benchmark == "adaboost":
            return model.estimator_weights_
        if benchmark in ("xgboost", "lightgbm"):
            return np.zeros(1)
        # Fail with the same error as _get_model instead of hitting an
        # unbound local variable.
        raise ValueError(f"Invalid benchmark method: {benchmark}")

    def _compute_accuracy(self, split: str, model) -> float:
        """Compute accuracy for a dataset split.

        Args:
            split: Key into ``self.data`` (``"train"`` or ``"test"``).
            model: Fitted model exposing ``predict``.

        Returns:
            The value of ``accuracy_score`` between the split's labels and
            the model's predictions.
        """
        # Use the same conversion as training; in particular this supplies
        # the features for XGBoost, which were previously left undefined.
        x, y_true = self._prepare_split(split)
        y_pred = model.predict(x)
        return accuracy_score(y_true, y_pred)
