import json
import os
import numpy as np
from sklearn.metrics import accuracy_score


class ResultTracker:
    def __init__(self, args):
        """
        Set up empty containers for per-iteration and end-of-run results.

        Args:
            args: Parsed arguments containing metadata about the experiment;
                stored unchanged on ``self.args``.
        """
        split_names = ("train", "val", "test")
        self.args = args

        # Per-iteration history: each list grows while the experiment runs.
        self.results = dict(
            accuracies={name: [] for name in split_names},
            computational_times=[],
            objvals=[],
            solvetimes=[],
            iterations=[],
        )

        # Values that are only known once all iterations are done.
        self.final_results = dict(
            tree_weights=[],
            margins={name: [] for name in split_names},
        )

    def update_results(
        self,
        data,
        weights,
        computational_time,
        objval,
        solvetime,
        iteration,
        model=None,
    ):
        """
        Update results for all splits (train, val, test) using weighted decisions.
        If running benchmarks, it collects accuracy per iteration.
        """
        if self.args.run_benchmark:
            # Compute accuracy per iteration for the benchmark model
            train_acc_per_iter, val_acc_per_iter, test_acc_per_iter = (
                self.get_benchmark_performance_per_iteration(
                    model,
                    data["train"]["x"],
                    data["train"]["y"],
                    data["val"]["x"],
                    data["val"]["y"],
                    data["test"]["x"],
                    data["test"]["y"],
                    method=self.args.benchmark,
                )
            )

            for i in range(len(train_acc_per_iter)):
                self.results["accuracies"]["train"].append(
                    train_acc_per_iter[i]
                )
                self.results["accuracies"]["val"].append(val_acc_per_iter[i])
                self.results["accuracies"]["test"].append(test_acc_per_iter[i])
                self.results["iterations"].append(iteration + i)

        else:
            for split in ["train", "val", "test"]:
                y_true = data[split]["y"]
                pred = np.vstack(data[split]["pred"])
                accuracy = self._calculate_metrics(pred, weights, y_true)
                self.results["accuracies"][split].append(accuracy)

            self.results["iterations"].append(iteration)

        # Store computational time, objective value, and solve time (for consistency)
        self.results["computational_times"].append(computational_time)
        self.results["objvals"].append(objval)
        self.results["solvetimes"].append(solvetime)

    def _calculate_metrics(self, pred, weights, y_true):
        """
        Calculate accuracy based on the correctness matrix and weights.

        """
        y_true = (
            np.array(y_true) if not isinstance(y_true, np.ndarray) else y_true
        )
        # Aggregate weighted pred
        aggregated_predictions = np.dot(weights, pred)
        # Correct predictions based on aggregated predictions
        predictions = np.where(aggregated_predictions > 0, 1, -1)
        y_pred_correct = np.where(predictions == y_true, 1, 0)
        return np.sum(y_pred_correct) / len(y_true)

    def get_benchmark_performance_per_iteration(
        self, model, X_train, y_train, X_val, y_val, X_test, y_test, method
    ):
        """
        Compute boosting model performance per iteration, considering the ensemble prediction at each step.
        """
        train_acc, val_acc, test_acc = [], [], []

        if method == "adaboost":
            # Initialize cumulative margins
            train_margins = np.zeros(len(X_train))
            val_margins = np.zeros(len(X_val))
            test_margins = np.zeros(len(X_test))

            # Iterate over weak learners and accumulate weighted predictions
            for t in range(len(model.estimators_)):
                weak_learner = model.estimators_[t]
                weight = model.estimator_weights_[t]  # Alpha_t

                # Update margins for each dataset
                for X, y, margins, acc_list in zip(
                    [X_train, X_val, X_test],
                    [y_train, y_val, y_test],
                    [train_margins, val_margins, test_margins],
                    [train_acc, val_acc, test_acc],
                ):
                    # Get current tree predictions and convert to {-1,1}
                    if self.args.tree_type == "blossom":
                        X = X.values.tolist()
                    pred = weak_learner.predict(X)
                    if self.args.tree_type == "blossom":
                        pred = 2 * pred - 1
                    margins += weight * pred
                    final_pred = np.where(margins > 0, 1, -1)
                    acc_list.append(accuracy_score(y, final_pred))

        elif method == "xgboost":  # XGBoost or LightGBM
            for t in range(
                1, model.n_estimators + 1
            ):  # Start from iteration 1
                for X, y, acc_list in zip(
                    [X_train, X_val, X_test],
                    [y_train, y_val, y_test],
                    [train_acc, val_acc, test_acc],
                ):
                    pred = model.predict(
                        X, iteration_range=(0, t), output_margin=True
                    )
                    acc_list.append(
                        accuracy_score(y, np.where(pred > 0, 1, -1))
                    )
        else:  # lightgbm
            train_preds = np.zeros(len(X_train))
            val_preds = np.zeros(len(X_val))
            test_preds = np.zeros(len(X_test))

            for t in range(1, model.n_estimators_ + 1):
                for X, y, preds, acc_list in zip(
                    [X_train, X_val, X_test],
                    [y_train, y_val, y_test],
                    [train_preds, val_preds, test_preds],
                    [train_acc, val_acc, test_acc],
                ):
                    preds += model.predict(X, num_iteration=t, raw_score=True)
                    acc_list.append(
                        accuracy_score(y, np.where(preds > 0, 1, -1))
                    )

        return train_acc, val_acc, test_acc

    def _calculate_margins(self, pred, weights, data):
        """
        Calculate margins for train, val, and test splits and store them in final results.
        """
        for split in ["train", "val", "test"]:
            # Extract true labels for the current split
            y = np.array(data[split]["y"])

            # Aggregate weighted decisions
            # Stack decision matrices and compute dot product with weights
            try:
                if self.args.tree_type == "blossom":
                    pred_split = (
                        2 * np.array(pred[split]) - 1
                    )  # DL8.5 outputs 0/1 labels instead of -1/+1
                else:
                    pred_split = pred[split]
                aggregated_decisions = np.dot(weights, pred_split)
            except:
                return

            # Calculate margins as product of aggregated decisions and true labels
            margins = y * aggregated_decisions

            # Store margins in final results
            self.final_results["margins"][split] = margins.tolist()

    def calculate_xgboost_margins(self, model, data):
        for split in ["train", "val", "test"]:
            # Extract true labels for the current split
            y = np.array(data[split]["y"])
            X = data[split]["x"]

            raw_margins = model.predict(X, output_margin=True)
            margins = y * raw_margins

            # Store margins in final results
            self.final_results["margins"][split] = margins.tolist()

    def calculate_lightgbm_margins(self, model, data):
        for split in ["train", "val", "test"]:
            # Extract true labels for the current split
            y = np.array(data[split]["y"])
            X = data[split]["x"]

            raw_margins = model.predict(X, raw_score=True)
            margins = y * raw_margins
            # Store margins in final results
            self.final_results["margins"][split] = margins.tolist()

    def prepare_input(self, x, tree_type):
        if tree_type == "blossom":
            return x.values.tolist()
        return x

    def _get_benchmark_pred(self, model, data):
        """
        Compute tree weights and margins for the selected benchmark model.
        """

        if self.args.benchmark == "adaboost":
            pred = {
                split: np.vstack(
                    [
                        est.predict(
                            self.prepare_input(
                                data[split]["x"], self.args.tree_type
                            )
                        )
                        for est in model.estimators_
                    ]
                )
                for split in ["train", "val", "test"]
            }

        else:  # XGBoost or LightGBM
            predict_params = (
                {"output_margin": True}
                if self.args.benchmark == "xgboost"
                else {"raw_score": True}
            )
            pred = {
                "train": model.predict(data["train"]["x"], **predict_params),
                "val": model.predict(data["val"]["x"], **predict_params),
                "test": model.predict(data["test"]["x"], **predict_params),
            }

        return pred

    def update_final(self, tree_weights, margins):
        """
        Update final results with data computed after all iterations.

        """
        self.final_results["tree_weights"] = tree_weights.tolist()
        self.final_results["margins"]["train"] = margins["train"]
        self.final_results["margins"]["val"] = margins["val"]
        self.final_results["margins"]["test"] = margins["test"]

    def save_to_json(self, directory="Results"):
        """
        Save all results to a JSON file.

        """
        # Combine iterative and final results
        results = {
            **vars(self.args),
            **self.results,
            **self.final_results,
        }

        # Create the directory if it doesn't exist
        os.makedirs(directory, exist_ok=True)

        # Build the filename
        if self.args.hyperparam is not None:
            hyperparam_value = self.args.hyperparam
        else:
            METHOD_HYPERPARAMS = {
                "cg_boost": "F",
                "erlp_boost": "G",
                "lp_boost": "D",
                "md_boost": "C",
                "qrlp_boost": "nu",
                "neg_margins": "E",
            }
            # Retrieve the method-specific hyperparameter dynamically
            method_hyperparam = METHOD_HYPERPARAMS.get(self.args.solver, None)
            if method_hyperparam is not None and hasattr(
                self.args, method_hyperparam
            ):
                hyperparam_value = getattr(self.args, method_hyperparam)
            else:
                raise ValueError(
                    f"Unknown method or missing hyperparameter for {self.args.solver}"
                )

        if self.args.run_benchmark:
            solver = self.args.benchmark
        else:
            solver = self.args.solver
        filename = os.path.join(
            directory,
            f"{self.args.dataset}_{solver}_depth{self.args.max_depth}_hyperparam{hyperparam_value}_treetype_{self.args.tree_type}_CRB{self.args.crb}_seed{self.args.seed}_singleshot{self.args.run_single_shot}.json",
        )

        # Save the results to the file
        with open(filename, "w") as f:
            json.dump(results, f, indent=4)

        print(f"Results saved to {filename}")

    def finalize_results(self, weights, data, model=None):
        """
        Finalize results by calculating margins, updating final results, and saving to JSON.

        """
        if self.args.run_benchmark:
            if self.args.benchmark == "adaboost":
                pred = self._get_benchmark_pred(model=model, data=data)
                self._calculate_margins(pred=pred, weights=weights, data=data)
            elif self.args.benchmark == "xgboost":
                self.calculate_xgboost_margins(model, data)
            elif self.args.benchmark == "lightgbm":
                self.calculate_lightgbm_margins(model, data)
            solver = self.args.benchmark
        else:
            pred = {
                "train": np.vstack(data["train"]["pred"]),
                "val": np.vstack(data["val"]["pred"]),
                "test": np.vstack(data["test"]["pred"]),
            }
            solver = self.args.solver

            self._calculate_margins(pred=pred, weights=weights, data=data)

        self.update_final(
            tree_weights=weights, margins=self.final_results["margins"]
        )
        dir = f"Results/{self.args.dataset}/{solver}"
        self.save_to_json(directory=dir)
