"""
    Metric class for evaluating models
"""
import logging
from collections import defaultdict
import yaml

import json
import numpy as np
import tqdm
import matplotlib.pyplot as plt

from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, roc_auc_score

from module.model import (RsetWrapper, Tree, KantchModel)
from module.utils import BinSequence
class Metrics:
    """ Metric Class
    """
    def __init__(self, config_file, rng=None):
        """ Load the ``metrics`` section of a YAML config and set up evaluation state.

        Args:
            config_file (str): Path to a YAML file that has a top-level
                ``metrics`` mapping with a non-empty ``flags`` entry.
            rng (np.random.Generator, optional): Random generator used when
                sub-sampling trees. A fresh ``np.random.default_rng()`` is
                created when None. Defaults to None.

        Raises:
            AssertionError: If ``metrics.flags`` is missing or empty.
        """
        # Bug fix: the original created a default generator and then
        # unconditionally overwrote it with ``rng`` (i.e. with None), so any
        # later ``self.rng.choice`` call failed.
        self.rng = rng if rng is not None else np.random.default_rng()
        self.config_file = config_file
        with open(self.config_file, "r", encoding="utf-8") as file:
            configs = yaml.safe_load(file)
        self.config = configs["metrics"]

        # Metric name -> hooks. ``evaluation`` calls the required "func" for
        # every model and the optional "init"/"clean" hooks once per metric.
        self.metrics_mapper = {
            "train_eval": {
                "func": self.train_eval,
            },
            "test_eval": {
                "func": self.test_eval,
            },
        }

        # ``or {}`` turns an explicit ``flags: null`` into the clear assertion
        # below instead of a TypeError.
        self.metrics_eval_list = self.config.get("flags") or {}
        assert self.metrics_eval_list, \
                "No metrics to evaluate. Please check the config file."

        #region Attributes that will be initialized/used later
        self.model = None
        self.eval_model = None
        self.encoder = None
        self.X_train = None
        self.y_train = None
        self.X_test = None
        self.y_test = None
        self.sensitive_features = None
        self.art_classifier = None
        self.X_member_train = None
        self.y_member_train = None
        self.X_nonmember_train = None
        self.y_nonmember_train = None
        self.X_eval = None
        self.y_eval = None
        self.eval_label = None
        self.kantch_adv_examples = None
        #endregion

    def eval_model_init(self, model, encoder=None) -> None:
        """ Register the model under evaluation together with its optional encoder.

        ``eval_model`` starts out as the very same object as ``model``;
        ``evaluation`` later repoints it at each individual model it scores.

        Args:
            model (model_wrapper): The model to be evaluated.
            encoder (encoder, optional): Transformer applied to feature matrices
                before prediction. Defaults to None (no encoding).
        """
        self.model = self.eval_model = model
        self.encoder = encoder

    def metric_data_init(self, X_train, y_train, X_test, y_test, sensitive_features=None):
        """ Initialize the data for evaluation

        The arrays are stored by reference (no copies are made here).

        Args:
            X_train (np.ndarray): Feature matrix for training.
            y_train (np.ndarray): Target vector for training.
            X_test (np.ndarray): Feature matrix for testing.
            y_test (np.ndarray): Target vector for testing.
            sensitive_features (np.ndarray, optional): Sensitive features
                                    for fairness evaluation. Defaults to None.
        """
        self.X_train = X_train
        self.y_train = y_train
        self.X_test = X_test
        self.y_test = y_test
        self.sensitive_features = sensitive_features

    def calculate_score(self, y_true, y_pred) -> dict:
        """ Score predicted labels against the ground truth.

        Only accuracy is reported at the moment.

        Args:
            y_true (np.ndarray): Ground-truth labels.
            y_pred (np.ndarray): Labels predicted by the model.

        Returns:
            dict: A single ``"accuracy"`` entry computed with
                ``sklearn.metrics.accuracy_score``.
        """
        # TODO: find a cleaner way to (optionally) report recall, precision,
        # F1, ROC-AUC and FPR, which are currently disabled.
        return {"accuracy": accuracy_score(y_true, y_pred)}

    def evaluation(self) -> dict:
        """ Run every configured metric against the model(s) under evaluation.

        For an ``RsetWrapper`` model, the density visualization runs first. Each
        ``(name, tree)`` pair from ``get_eval_tree`` then gets probabilities
        generated on the training set and is evaluated separately, with result
        keys prefixed by the tree name. Any other model is evaluated as-is.

        For each metric, the optional ``"init"`` hook runs once, then ``"func"``
        runs once per model (with ``self.eval_model`` pointing at that model),
        then the optional ``"clean"`` hook runs, even if a metric raised.

        Returns:
            dict: ``"[<model_name>_]<metric>_<score_name>" -> value``, plus the
                density-visualization entries for ``RsetWrapper`` models.

        Raises:
            KeyError: If a configured metric has no entry in ``metrics_mapper``.
        """
        logger = logging.getLogger("Metrics.evaluation")
        eval_result = {}

        if isinstance(self.model, RsetWrapper):
            eval_result.update(self.density_visualization())

            # Materialize once: ``models`` is iterated again for every metric
            # below, and a one-shot iterator would silently yield nothing then.
            models = list(self.model.get_eval_tree(self.encoder))

            # The encoded training matrix does not depend on the tree, so encode
            # once (on a copy, as before); each tree still gets its own copy.
            if self.encoder is not None:
                encoded_train = self.encoder.transform(self.X_train.copy())
            else:
                encoded_train = self.X_train
            for name, tree in models:
                tree.generate_proba(encoded_train.copy(), self.y_train)
                logger.info("Tree %s with json %s", name, json.dumps(tree.to_json(), indent=4))
        else:
            models = [[None, self.model]]

        logger.info("Metrics list: %s", self.metrics_eval_list)
        for metric in self.metrics_eval_list:
            logger.info("Running metric: %s", metric)
            try:
                hooks = self.metrics_mapper[metric]
            except KeyError:
                raise KeyError(
                    f"Unknown metric {metric!r}; expected one of {sorted(self.metrics_mapper)}"
                ) from None

            hooks.get("init", lambda: None)()
            try:
                for model_name, model in models:
                    logger.info("Running model: %s", model_name)
                    self.eval_model = model
                    prefix = metric if model_name is None else f"{model_name}_{metric}"
                    for ret_name, value in hooks["func"]().items():
                        eval_result[f"{prefix}_{ret_name}"] = value
            finally:
                # Release any per-metric cached state even if a metric fails.
                hooks.get("clean", lambda: None)()
        return eval_result

    

    #region Standard
    def train_eval(self) -> dict:
        """ Train score evaluation

        Returns:
            dict: Dictionary of scores
        """
        if self.encoder is not None:
            X_train = self.encoder.transform(self.X_train)
        else:
            X_train = self.X_train.copy()
        y_pred = self.eval_model.predict(X_train)
        return self.calculate_score(self.y_train, y_pred)

    def test_eval(self) -> dict:
        """ Test score evaluation

        Returns:
            dict: Dictionary of scores
        """
        if self.encoder is not None:
            X_test = self.encoder.transform(self.X_test)
        else:
            X_test = self.X_test.copy()
        y_pred = self.eval_model.predict(X_test)
        return self.calculate_score(self.y_test, y_pred)

    #endregion

    ###
    #region Density Visualization
    ###
    def density_visualization(self, max_trees=500_000) -> dict:
        """ Score the trees of an RSET model on Kantch adversarial test examples.

        Adversarial examples are generated from the raw test set against the
        model's optimal tree. Up to ``max_trees`` trees are then used; when the
        set is larger, that many are drawn uniformly without replacement. The
        selected trees predict the clean test set and are scored on the
        adversarial one. The tree with the highest adversarial score is
        recorded in ``self.model.best_tree["best_Kantch_test_tree"]``.

        Args:
            max_trees (int, optional): Upper bound on the number of trees
                evaluated. Defaults to 500_000.

        Returns:
            dict: ``"kantch_test_scores"`` (adversarial scores of the selected
                trees), ``"test_preds"`` (clean test predictions converted via
                ``BinSequence(0, 0).from_array(...).x``) and ``"indices"`` (the
                tree indices that were evaluated).

        Raises:
            AssertionError: If ``self.model`` is not an ``RsetWrapper``.
        """
        logger = logging.getLogger("Metrics.density_visualization")
        logger.info("Density visualization")
        assert isinstance(self.model, RsetWrapper), "Density visualization only works with RSET"
        X_adv = self.kantch_example_gen(self.model.get_optimal_tree(self.encoder),
                                        self.X_test, self.y_test)
        if self.encoder is not None:
            X_test = self.encoder.transform(self.X_test)
            X_adv = self.encoder.transform(X_adv)
        else:
            X_test = self.X_test

        n_trees = self.model.ntrees()
        if n_trees > max_trees:
            # Guard: a missing generator would otherwise crash here.
            if self.rng is None:
                self.rng = np.random.default_rng()
            indices = self.rng.choice(n_trees, max_trees, replace=False)
        else:
            indices = np.arange(n_trees)

        predictions = self.model.predict_many(X_test, indices)
        predictions = np.array([BinSequence(0, 0).from_array(pred).x for pred in predictions])

        adv_scores = self.model.score_many(X_adv, self.y_test, indices)
        # ``argmax`` gives a position within ``indices``, not a tree id. Map it
        # back, otherwise the stored tree is wrong whenever trees were
        # sub-sampled. NOTE(review): assumes score_many returns one score per
        # entry of ``indices``, in the same order; confirm.
        best_position = np.argmax(adv_scores)
        self.model.best_tree["best_Kantch_test_tree"] = indices[best_position]

        return {"kantch_test_scores": adv_scores, "test_preds": predictions, "indices": indices}
    #endregion

    ###
    #region Robustness
    ###
    def kantch_evasion_init(self):
        """ Prepare the Kantch evasion (robustness) metric.

        Builds the ``[name, model]`` attack targets and passes them to
        ``kantch_model_init``, which generates and stores adversarial examples.
        For an ``RsetWrapper`` the targets are the optimal tree and the tree
        recorded under ``best_tree["best_Kantch_test_tree"]``. Any other model
        contributes only ``self.eval_model``, with a ``None`` name.
        """
        logger = logging.getLogger("Metrics.kantch_evasion_init")
        logger.info("Robustness metrics init")
        if not isinstance(self.model, RsetWrapper):
            self.kantch_model_init([[None, self.eval_model]])
            return

        # Random / min-leaf / max-leaf trees are further candidate targets that
        # are currently disabled.
        optimal = self.model.get_optimal_tree(self.encoder)
        best_idx = self.model.best_tree["best_Kantch_test_tree"]
        targets = [
            ["optimal_tree", optimal],
            ["best_kantch_test_tree", self.model.get_tree(best_idx, self.encoder)],
        ]
        self.kantch_model_init(targets)

    def kantch_evasion_clean(self):
        """ Discard the adversarial examples cached by ``kantch_model_init``.

        Resets ``kantch_adv_examples`` to None so the stored arrays can be freed.
        """
        cleared = None
        self.kantch_adv_examples = cleared

    def kantch_model_gen(self, model):
        """ Convert a ``Tree`` into a ``KantchModel`` that can craft adversarial examples.

        Args:
            model (Tree): The tree to convert; it is serialized with ``to_json()``.

        Returns:
            KantchModel: Model built from the tree's JSON. The second
                constructor argument is fixed to 2 (NOTE(review): presumably
                the number of classes; confirm).

        Raises:
            TypeError: If ``model`` is not a ``Tree``.
        """
        if not isinstance(model, Tree):
            # Explicit raise: an ``assert`` is stripped under ``python -O``,
            # which would silently return the unconverted model.
            raise TypeError(f"Invalid model class: {type(model)}")
        return KantchModel(model.to_json(), 2)

    def kantch_example_gen(self, model, X, y, epsilon=0.1, order=np.inf, log_stats=False):
        """ Generate adversarial examples for ``(X, y)`` with a Kantch model.

        ``model`` is converted with ``kantch_model_gen``. Copies of ``X`` and
        ``y`` are passed to the attack rather than the caller's arrays.

        Args:
            model (Tree): Model to attack; converted with ``kantch_model_gen``.
            X (np.ndarray): Feature matrix.
            y (np.ndarray): Target vector.
            epsilon (float, optional): Kantch epsilon, forwarded as
                ``options["epsilon"]``. Defaults to 0.1.
            order (optional): Order for the distance calculation, forwarded to
                ``adversarial_examples``. Defaults to ``np.inf``.
            log_stats (bool, optional): Forwarded as ``options["logging"]``.
                Defaults to False.

        Returns:
            Whatever ``KantchModel.adversarial_examples`` returns. This is
            presumably the adversarial counterpart of ``X`` (NOTE(review):
            confirm type/shape).
        """
        kantch_model = self.kantch_model_gen(model)
        # The progress bar is always disabled; logging is opt-in via log_stats.
        options = {"epsilon": epsilon, "disable_progress_bar": True, "logging": log_stats}
        return kantch_model.adversarial_examples(X.copy(), y.copy(), options=options, order=order)

    def kantch_model_init(self, target_model_list):
        """ Kantch model initialization and generate adversarial examples.

        Epsilon and order are read from
        ``config["params"]["kantch_evasion_attack"]``. For every target, its
        bounding boxes on the test set are logged (debug output). Adversarial
        examples are then generated for both the train and test sets and
        appended to ``self.kantch_adv_examples`` (reset to an empty list first)
        as ``[examples_name, X_adv, y]``. ``examples_name`` is ``"on_train"``
        or ``"on_test"``, suffixed with ``"_from_<model_name>"`` when the name
        is not None.

        Args:
            target_model_list (list): ``[model_name, model]`` pairs of target
                models for generating adversarial examples; ``model_name`` may
                be None.
        """
        logger = logging.getLogger("Metrics.kantch_model_init")
        attack_params = self.config["params"]["kantch_evasion_attack"]
        epsilon = attack_params["epsilon"]
        order = attack_params["order"]

        logger.info("Kantch model init with epsilon %s", epsilon)
        self.kantch_adv_examples = []

        for model_name, target_model in target_model_list:
            # TODO: debug-only bounding-box dump; slated for removal.
            model = self.kantch_model_gen(target_model)
            bounding_boxes, _ = model.get_bounding_boxes(self.X_test)
            # Build the (potentially large) report only when it will be emitted.
            if logger.isEnabledFor(logging.INFO):
                parts = ["\n"]
                for idx, bounding_box in enumerate(bounding_boxes):
                    parts.append(f"Bounding box {idx}:\n")
                    # Features whose interval is (-inf, inf) are omitted.
                    parts.extend(
                        f"\tFeature {feat_idx}: ({box[0]}, {box[1]})\n"
                        for feat_idx, box in enumerate(bounding_box)
                        if not (box[0] == -np.inf and box[1] == np.inf)
                    )
                    parts.append("\n")
                logger.info("Bounding box for %s %s", model_name, "".join(parts))

            for X, y, data_name in [(self.X_train, self.y_train, "train"),
                                    (self.X_test, self.y_test, "test")]:
                logger.info("Generating adversarial examples for %s on model %s", data_name, model_name)
                X_adv = self.kantch_example_gen(target_model, X, y, epsilon, order, True)
                examples_name = f"on_{data_name}"
                if model_name is not None:
                    examples_name += f"_from_{model_name}"
                self.kantch_adv_examples.append([examples_name, X_adv, y])
