from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
import numpy as np
from sklearn.ensemble import AdaBoostClassifier
import time
from sklearn.metrics import accuracy_score
from gurobipy import Env
from src.solvers.solver import Solver
from src.solvers.utils import get_solver
from src.utils.data_handler import DataHandler
from src.utils.results_tracker import ResultTracker
import numpy.typing as npt


class SingleShotTrainer:
    """Fit a boosted tree ensemble once, then re-weight its trees with one solver call.

    The trainer owns a Gurobi ``Env`` that is released by ``close()`` (also
    invoked from ``__del__``).
    """

    # String annotation so the project type is not needed at class-definition time.
    solver: "Solver"

    def __init__(self, args):
        """Load all data splits and set up the tracker, solver and Gurobi env.

        Args:
            args: Run configuration. Read here: ``args.solver``; read later in
                ``train``: ``max_depth``, ``seed`` and ``itermax``. The whole
                object is also passed to ``DataHandler``, ``ResultTracker`` and
                ``solver.solve``.
        """
        self.args = args
        self.data_handler = DataHandler(self.args)
        self.data = self.data_handler.get_all_splits()
        self.tracker = ResultTracker(self.args)
        self.solver = get_solver(self.args.solver)
        # An empty LogFile tells Gurobi not to write a log file.
        self.env = Env(params={"LogFile": ""})

    def close(self) -> None:
        """Dispose of the Gurobi environment; safe to call more than once."""
        # getattr: __init__ may have failed before ``self.env`` was assigned.
        env = getattr(self, "env", None)
        if env is not None:
            # Clear first so a repeated call (e.g. from __del__) never disposes twice.
            self.env = None
            env.dispose()

    def __del__(self):
        """Ensure proper cleanup of Gurobi environment."""
        self.close()

    def train(self) -> None:
        """Fit AdaBoost, re-weight its trees via the solver, and record results.

        Steps:
          1. Fit an ``AdaBoostClassifier`` of up to ``args.itermax`` decision
             trees (depth ``args.max_depth``) and measure its train/test
             accuracy as a baseline.
          2. Store every fitted tree's predictions on each data split through
             ``self.data_handler.add_pred``.
          3. Call ``self.solver.solve`` on the training split to obtain the
             optimized tree weights, timing the call.
          4. Push the result to the tracker, print baseline vs. reweighted
             accuracy, and finalize the results.
        """
        # 1. Fit the boosted ensemble and measure its un-reweighted accuracy.
        base_learner = DecisionTreeClassifier(
            max_depth=self.args.max_depth, random_state=self.args.seed
        )
        ensemble = AdaBoostClassifier(
            estimator=base_learner,
            n_estimators=self.args.itermax,
            random_state=self.args.seed,
        )

        train_split, test_split = self.data["train"], self.data["test"]
        ensemble.fit(train_split["x"], train_split["y"])
        ada_train_acc = accuracy_score(
            train_split["y"], ensemble.predict(train_split["x"])
        )
        ada_test_acc = accuracy_score(
            test_split["y"], ensemble.predict(test_split["x"])
        )

        # 2. Record each fitted tree's predictions on every split. Iterate over
        #    ``estimators_`` rather than ``range(itermax)``: per the scikit-learn
        #    docs, boosting stops early on a perfect fit.
        for tree in ensemble.estimators_:
            self._create_pred_matrices(tree, self.data)

        # 3. Solve for the tree weights.
        # NOTE(review): the solver presumably expects labels in {-1, +1}; this is
        # not checked here.
        start = time.perf_counter()
        _sample_weights, _beta, optim_weights, objval, solvetime = (
            self.solver.solve(self.args, self.data["train"], self.env)
        )
        train_time = time.perf_counter() - start

        # 4. Store, report and finalize results.
        self.tracker.update_results(
            data=self.data,
            weights=optim_weights,
            computational_time=train_time,
            objval=objval,
            solvetime=solvetime,
            iteration=1,
        )
        print(
            f"Adaboost Train Accuracy: {ada_train_acc:.4f}, Test Accuracy: {ada_test_acc:.4f}\n"
            f"Reweighted Train Accuracy: {self.tracker.results['accuracies']['train'][-1]:.4f}, Test Accuracy: {self.tracker.results['accuracies']['test'][-1]:.4f}"
        )
        self.tracker.finalize_results(weights=optim_weights, data=self.data)

    def _create_pred_matrices(self, learner: object, data: dict) -> None:
        """Predict with ``learner`` on every split and hand each result to the data handler.

        Args:
            learner: Fitted model exposing ``predict``.
            data: Mapping of split name to a dict containing at least ``"x"``.
        """
        for split, split_data in data.items():
            predictions = self._create_pred_matrix(learner, split_data["x"])
            self.data_handler.add_pred(split, predictions)

    def _create_pred_matrix(
        self,
        learner: object,
        x: npt.ArrayLike,
    ) -> np.ndarray:
        """Return ``learner``'s predictions on ``x`` as a NumPy array.

        Non-array inputs (lists, DataFrames, ...) are converted to an ndarray
        before prediction; ndarrays (including subclasses) are passed through
        unchanged.
        """
        features = np.asanyarray(x)
        # np.array copies, so the stored result never aliases learner internals.
        return np.array(learner.predict(features))
