"""
This module provides a wrapper around the TREEFARMS model for easier interaction and evaluation.
"""
import logging
from collections import defaultdict

import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score

from treefarms import TREEFARMS
from gosdt import GOSDTClassifier

from .tree_classifier import Tree


class RsetWrapper:
    """ ResetWrapper class to interact with the TREEFARMS model """
    def __init__(self, config):
        """ Initialize the RsetWrapper with the given configuration.

        Args:
            config (dict): Configuration for the TREEFARMS model.
        """
        self.config = config
        self.model = TREEFARMS(config)
        self.optimal = GOSDTClassifier()
        self.best_tree = defaultdict(None)
        self.best_tree_score = defaultdict(None)

        self.special_tree = {}
        self.n_features_ = None
        self.classes_ = None

    def fit(self, X, y):
        """ Fit the TREEFARMS model to the data.

        Args:
            X (np.ndarray): Feature matrix.
            y (np.ndarray): Target vector.
        """
        self.model.fit(pd.DataFrame(X), pd.Series(y))
        self.best_tree = {}
        self.n_features_ = X.shape[1]
        self.classes_ = np.unique(y)
    
    def predict(self, X, idx=None):
        """ Predict using the model at the given index.

        Args:
            X (np.ndarray): Feature matrix.
            idx (int, optional): Index of the model to use for prediction. Defaults to None.

        Returns:
            np.ndarray: Predicted labels.
        """
        if idx is None:
            model = self.get_optimal_tree()
        else:
            model = self.get_tree(idx)
        return model.predict(X)

    def predict_all(self, X):
        """ Generate predictions for all trees in the model.
        This will take a while if the number of trees is large.

        Args:
            X (np.ndarray): Feature matrix.

        Returns:
            np.ndarray: Array of predictions from all trees.
        """
        logger = logging.getLogger("RsetWrapper.predict_all")
        logger.info("Predicting %s trees", self.ntrees())
        if self.ntrees() > 100_000:
            logger.warning("Number of trees is large, this may take a while.")
        predictions = []
        for i in range(self.ntrees()):
            model = self.get_tree(i)
            pred = model.predict(X)
            predictions.append(pred)
        return np.array(predictions)
    
    def predict_many(self, X, indices):
        """

        """
        logger = logging.getLogger("RsetWrapper.predict_many")
        logger.info("Predicting %s trees", len(indices))
        if len(indices) > 100_000:
            logger.warning("Number of trees is large, this may take a while.")
        predictions = []
        for i in indices:
            model = self.get_tree(i)
            pred = model.predict(X)
            predictions.append(pred)
        return np.array(predictions)
    
    def score_many(self, X, y, indices):
        """ Calculate the scores of the models at the given indices.

        Args:
            X (np.ndarray): Feature matrix.
            y (np.ndarray): Target vector.
            indices (list): List of indices of the models to evaluate.

        Returns:
            np.ndarray: Array of scores for each model.
        """
        logger = logging.getLogger("RsetWrapper.score_many")
        logger.info("Scoring %s trees", len(indices))
        if len(indices) > 100_000:
            logger.warning("Number of trees is large, this may take a while.")
        scores = []
        for i in indices:
            model = self.get_tree(i)
            score = model.score(X, y)
            scores.append(score)
        return np.array(scores)

    def score_all(self, X, y):
        """ Calculate the scores of all models in the model set.

        Args:
            X (np.ndarray): Feature matrix.

        Returns:
            np.ndarray: Array of scores for each model.
        """
        logger = logging.getLogger("RsetWrapper.score_all")
        logger.info("Scoring %s trees", self.ntrees())
        if self.ntrees() > 100_000:
            logger.warning("Number of trees is large, this may take a while.")
        scores = []
        for i in range(self.ntrees()):
            model = self.get_tree(i)
            score = model.score(X, y)
            scores.append(score)
        return np.array(scores)

    def score(self, X, y, idx) -> float:
        """ Calculate the score of the model at the given index.

        Args:
            X (np.ndarray): Feature matrix.
            y (np.ndarray): Target vector.
            idx (int): Index of the model to evaluate.

        Returns:
            float: Score of the model.
        """
        if idx is None:
            model = self.get_optimal_tree()
        else:
            model = self.get_tree(idx)
        return model.score(X, y)

    def score_many_trees(self, X, y, rng, max_tree: int=100_000) -> list:
        """ Calculate the scores of many models up to a maximum number.

        Args:
            X (np.ndarray): Feature matrix.
            y (np.ndarray): Target vector.
            rng (generator): Random number generator.
            max_tree (int, optional): Maximum number of trees to evaluate. Defaults to 100_000.

        Returns:
            list: List of scores and their corresponding indices.
        """
        scores = []
        ntrees = self.ntrees()
        if ntrees == max_tree:
            indices = np.arange(ntrees)
        else:
            indices = np.sort(rng.choice(ntrees, min(ntrees,max_tree), replace=False))

        for idx in indices:
            score = self.score(X, y, idx)
            scores.append((score, idx))
        return np.array(scores)

    def select_fair_tree(self, X, y, rng, metrics):
        """ Select the best tree based on multiple fairness metrics.
        This function evaluates each tree based on the provided metrics and selects the best one.

        Args:
            X (np.ndarray): Feature matrix.
            y (np.ndarray): Target vector.
            rng (generator): Random number generator.
            metrics (dict): Dictionary of metrics to evaluate the trees.
        """
        logger = logging.getLogger("RsetWrapper.select_fair_tree")
        ntrees = self.ntrees()
        max_tree = min(ntrees, 100_000)
        if ntrees == max_tree:
            indices = np.arange(ntrees)
        else:
            indices = np.sort(rng.choice(ntrees, min(ntrees,max_tree), replace=False))

        metrics["accuracy"] = accuracy_score
        metric_scores = defaultdict(list)
        for idx in indices:
            model = self.get_tree(idx)
            y_pred = model.predict(X)
            for metric_name, metric_func in metrics.items():
                score = metric_func(y, y_pred)
                metric_scores[metric_name].append((score, idx))

        acc_score = np.array(metric_scores["accuracy"])[:, 0]
        for metric_name in metrics.keys():
            if metric_name == "accuracy":
                continue
            metric_score = np.array(metric_scores[metric_name])
            objective = (1 - acc_score) + 0.3 * metric_score[:, 0]
            best_idx = np.argmin(objective)
            best_tree = metric_score[best_idx, 1]
            logger.info("Best Tree for %s: %s with score: %s", metric_name,
                    best_tree, metric_score[best_idx, 0])

            self.best_tree[metric_name] = best_tree

            if getattr(self, "best_tree_score", None) is None:
                self.best_tree_score = defaultdict(None)

            self.best_tree_score[metric_name] = metric_score[best_idx, 0]

    def select_tree(self, X, y, metrics, rng):
        """ Select the best tree based on a single metric.
        This function evaluates each tree based on the provided metric and data
                and selects the best one.
        Args:
            X (np.ndarray): Feature matrix.
            y (np.ndarray): Target vector.
            metrics (str): Metric name of the given data.
            rng (generator): Random number generator.
        """
        logger = logging.getLogger("RsetWrapper.select_tree")
        max_tree = min(self.ntrees(), 100_000)
        scores = self.score_many_trees(X, y, rng, max_tree)
        sort_scores = np.argsort(scores[:, 0])[::-1]
        sorted_scores = scores[sort_scores]
        tree_indices = sorted_scores[:, 1]
        logger.info("Best Tree: %s with score: %s, metrics: %s", tree_indices[0],
                sorted_scores[0, 0], metrics)
        self.best_tree[metrics] = tree_indices[0]
        self.best_tree_score[metrics] = sorted_scores[0, 0]

    def tune(self, nested_cv, encoder=None) -> dict:
        """ Tune the model using nested cross-validation to find the best hyperparameters.
        This function evaluates different configurations of depth_budget and regularization
        and selects the best one based on the average score across folds.

        Args:
            nested_cv (list): Nested cross-validation data.
            encoder (encoder, optional): Encoder for the data. Defaults to None.

        Returns:
            tuple: Best configuration and all scores.
        """
        logger = logging.getLogger("RsetWrapper Tune")
        depth_budgets = [2, 3, 4, 5]
        lambs = [0.02, 0.015, 0.01, 0.005]

        nested_cv = list(nested_cv)
        best_score = -float('inf')
        best_config = None
        all_scores = {}
        for depth_budget in depth_budgets:
            for lamb in lambs:
                logger.debug("Depth Budget: %s, Lambda: %s", depth_budget, lamb)
                config = {
                    'depth_budget': depth_budget,
                    'regularization': lamb,
                    'allow_small_reg': True,
                }

                scores = []
                for fold_id, (X, y), (X_val, y_val) in nested_cv:
                    if encoder is not None:
                        encoder.fit(X, y)
                        X, _ = encoder.transform(X)
                        X_val, _ = encoder.transform(X_val)
                    gosdt_model = GOSDTClassifier(**config)
                    gosdt_model.fit(X, y)

                    score = gosdt_model.score(X_val, y_val)
                    logger.debug("Fold %s Score: %s", fold_id, score)
                    scores.append(score)
                avg_score = np.mean(scores)
                logger.info("Depth Budget: %s, Lambda: %s, Avg Score: %s",
                        depth_budget, lamb, avg_score)
                all_scores[(depth_budget, lamb)] = scores

                if avg_score >= best_score:
                    best_score = avg_score
                    best_config = config
        logger.info("Best Config: %s with score: %s", best_config, best_score)
        return best_config, all_scores

    def get_tree(self, idx: int, encoder=None) -> Tree:
        """ Get the tree at the given index.

        Args:
            idx (int): Index of the tree.
            encoder (encoder, optional): Encoder for the data. Defaults to None.

        Returns:
            Tree: Tree object at the given index.
        """
        tree = Tree(self.model.model_set.get_tree_at_idx_raw(idx), encoder)
        tree.n_features_ = self.n_features_
        tree.classes_ = self.classes_
        return tree

    def get_eval_tree(self, encoder) -> list:
        """ Get all evaluation trees.
        This function returns a list of trees for evaluation, including the optimal tree,
        the minimum leaf tree, and the maximum leaf tree. Also the special
        trees selected by different metrics.

        Args:
            encoder (encoder): Encoder for the data.

        Returns:
            list: List of evaluation trees.
        """
        trees = []
        for name, idx in self.special_tree.items():
            tree = self.get_tree(idx, encoder)
            trees.append([name, tree])
        for metric, idx in self.best_tree.items():
            tree = self.get_tree(idx, encoder)
            trees.append([f"{metric}_tree", tree])
        return trees

    def ntrees(self) -> int:
        """ Get the number of trees in the model.

        Returns:
            int: Number of trees in the model.
        """
        return self.model.model_set.model_count

    def get_random_tree(self, rng, encoder=None) -> Tree:
        """ Get a random tree from the model.
        This function selects a random tree from the model and returns it.

        Args:
            rng (generator): Random number generator.
            encoder (encoder, optional): Encoder for the data. Defaults to None.

        Returns:
            Tree: Randomly selected tree.
        """
        return self.get_tree(rng.choice(self.ntrees()), encoder)

    def get_optimal_tree(self, encoder=None) -> Tree:
        """ Get the optimal tree from the model.
        This function selects the optimal tree based on the minimum objective value
        and returns it.

        Args:
            encoder (encoder, optional): Encoder for the data. Defaults to None.

        Returns:
            Tree: Optimal tree.
        """
        return self.get_tree(self.special_tree["optimal_tree"], encoder)

    def get_min_leaf_tree(self, encoder=None) -> Tree:
        """ Get the minimum leaf tree from the model.

        Args:
            encoder (encoder, optional): Encoder for the data. Defaults to None.

        Returns:
            Tree: Minimum leaf tree.
        """
        return self.get_tree(self.special_tree["min_leaf_tree"], encoder)

    def get_max_leaf_tree(self, encoder=None) -> Tree:
        """ Get the maximum leaf tree from the model.

        Args:
            encoder (encoder, optional): Encoder for the data. Defaults to None.

        Returns:
            Tree: Maximum leaf tree.
        """
        return self.get_tree(self.special_tree["max_leaf_tree"], encoder)

    def find_special_tree(self, rng) -> None:
        """ Find special trees based on different criteria.
        This function selects the optimal tree, the minimum leaf tree, and the maximum leaf tree
        and stores them in the special_tree dictionary. If there are multiple trees
        that meet the criteria, it randomly selects one.

        Args:
            rng (generator): Random number generator.
        """
        self.special_tree = {}
        self.special_tree["optimal_tree"] = rng.choice(self.get_optimal_tree_idx())
        self.special_tree["min_leaf_tree"] = rng.choice(self.get_min_leaf_tree_idx())
        self.special_tree["max_leaf_tree"] = rng.choice(self.get_max_leaf_tree_idx())

    def get_raw_idx_from_pointer_idx(self, pointer: str, i: int) -> int:
        """ Get the raw index from the pointer index.
        This function calculates the raw index of a tree based on its pointer
        and index within that pointer.

        Args:
            pointer (str): Pointer to the model set.
            i (int): Index within the model set.

        Returns:
            int: Raw index of the tree.
        """
        model_set = self.model.model_set
        count = 0
        for entry in model_set.available_metrics["metric_pointers"]:
            if entry == pointer:
                break
            count += model_set.get_model_set(entry)["count"]
        return count + i

    def get_optimal_tree_idx(self) -> list:
        """ Get the optimal tree index.
        This function finds the index of the optimal tree based on the minimum objective value
        and returns the raw indices of the trees that meet this criterion.

        Returns:
            list: List of raw indices of the optimal trees.
        """
        model_set = self.model.model_set
        min_obj_idx = np.argmin(np.array(model_set.available_metrics["metric_values"])[:,0])
        pointer = model_set.available_metrics["metric_pointers"][min_obj_idx]
        optimal_tree_model_set = model_set.storage[pointer]
        count = optimal_tree_model_set["count"]
        raw_idx = []
        for i in range(count):
            raw_idx.append(self.get_raw_idx_from_pointer_idx(pointer, i))
        return raw_idx

    def get_tree_num_leaves(self, idx: int) -> int:
        """ Get the number of leaves in the tree at the given index.

        Args:
            idx (int): Index of the tree.

        Returns:
            int: Number of leaves in the tree.
        """
        return self.model[idx].leaves()

    def get_min_leaf_tree_idx(self):
        """ Get the minimum leaf tree index.

        Returns:
            list: List of raw indices of the minimum leaf trees.
        """
        model_set = self.model.model_set
        metric_values = np.array(model_set.available_metrics["metric_values"])
        min_leaf_idx = np.where(metric_values[:,-1] == min(metric_values[:,-1]))[0]
        # if ties, break tie
        if len(min_leaf_idx) > 1:
            min_leaf_idx = min_leaf_idx[np.argmin(metric_values[min_leaf_idx,1])]
        else:
            min_leaf_idx = min_leaf_idx[0]
        pointer = model_set.available_metrics["metric_pointers"][min_leaf_idx]
        min_leaf_tree_model_set = model_set.storage[pointer]
        count = min_leaf_tree_model_set["count"]
        raw_idx = []
        for i in range(count):
            raw_idx.append(self.get_raw_idx_from_pointer_idx(pointer, i))
        return raw_idx

    def get_max_leaf_tree_idx(self):
        """ Get the maximum leaf tree index.
        This function finds the index of the maximum leaf tree based on the maximum objective value
        and returns the raw indices of the trees that meet this criterion.

        Returns:
            list: List of raw indices of the maximum leaf trees.
        """
        model_set = self.model.model_set
        metric_values = np.array(model_set.available_metrics["metric_values"])
        max_leaf_idx = np.where(metric_values[:,-1] == max(metric_values[:,-1]))[0]
        # if ties, break tie
        if len(max_leaf_idx) > 1:
            max_leaf_idx = max_leaf_idx[np.argmin(metric_values[max_leaf_idx,1])]
        else:
            max_leaf_idx = max_leaf_idx[0]
        pointer = model_set.available_metrics["metric_pointers"][max_leaf_idx]
        max_leaf_tree_model_set = model_set.storage[pointer]
        count = max_leaf_tree_model_set["count"]
        raw_idx = []
        for i in range(count):
            raw_idx.append(self.get_raw_idx_from_pointer_idx(pointer, i))
        return raw_idx
