import torch
from ar.utils import RESET, COLORS, safe_load_tensor
import os
from ar.config import LogicConfig
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier
import joblib
from typing import Dict, List
from sklearn.tree import plot_tree as sk_plot_tree
from tqdm import tqdm
from datetime import datetime


class Concepts:
    """Per-concept SAE latent selection built from cached activations.

    On construction, loads ``{concept}_positive.pt`` and ``{concept}_negative.pt``
    from ``cache_dir`` for every concept and, depending on
    ``config.search_strategy``, either ranks SAE latents per concept ("top_k")
    or loads/fits one decision tree per concept ("tree").
    """

    def __init__(
        self,
        concepts: list[str],
        cache_dir: str,
        config: LogicConfig,
        verbose: bool = True,
    ):
        """Initialize the Concepts class and immediately build the concept data.

        Args:
            concepts (list[str]): List of concept names to be processed.
            cache_dir (str): Directory holding the cached
                ``{concept}_positive.pt`` / ``{concept}_negative.pt`` files.
            config (LogicConfig): Configuration object containing search parameters.
            verbose (bool): If True, print additional information during processing.

        Raises:
            FileNotFoundError: If a concept's cached activation file is missing.
            ValueError: If the search strategy is invalid or a tree search has no depth.
        """
        self.concepts = concepts
        self.search_concept_type = config.search_concept_type  # "word", "sentence", "position"
        self.search_concept_token = config.search_concept_token  # "all" or "last"
        self.search_strategy = config.search_strategy  # "top_k" or "tree"
        # top_k search parameters
        self.search_top_k = config.search_top_k  # number of latents kept per concept
        # "unique_first", "unique_only" or "original_order"
        self.search_top_k_order = config.search_top_k_order
        # tree search parameters: maximum tree depth; must be set for "tree"
        self.search_tree_depth = config.search_tree_depth

        self.verbose = verbose

        self.cache_dir = cache_dir
        # concept_forrest: concept -> fitted tree (None for top_k search)
        # concept_dict: concept -> dict of selected latent indices and per-index values
        self.concept_forrest, self.concept_dict = {}, {}
        self.build_concepts()

    def build_concepts(self):
        """Load cached activations and build ``concept_dict`` / ``concept_forrest``.

        Raises:
            FileNotFoundError: If a concept's positive or negative file is missing.
            ValueError: On an unknown strategy, or a tree search without a depth.
        """
        sae_latent_activations = {"positive": {}, "negative": {}}
        for concept in self.concepts:
            try:
                # Both strategies need the negative samples: top_k subtracts
                # their mean, tree search trains on them as class 0.
                sae_latent_activations["positive"][concept] = safe_load_tensor(
                    f"{self.cache_dir}/{concept}_positive.pt"
                )
                sae_latent_activations["negative"][concept] = safe_load_tensor(
                    f"{self.cache_dir}/{concept}_negative.pt"
                )
            except FileNotFoundError as e:
                # Report which concept is missing, keeping the original cause.
                raise FileNotFoundError(
                    f"Could not find SAE indices for concept {concept} in {self.cache_dir}. Please run search first."
                ) from e

        if self.search_strategy == "top_k":
            if self.verbose:
                print(
                    f"Greedy concept search (top-k concepts ={self.search_top_k}, {self.search_top_k_order})"
                )
            self.concept_dict = compute_top_k_activations(
                sae_latent_activations,
                k=self.search_top_k,
                search_top_k_order=self.search_top_k_order,
                verbose=self.verbose,
            )  # dict of concept to indices
            self.concept_forrest = None  # No decision tree for top_k search
        elif self.search_strategy == "tree":
            if self.search_tree_depth is None:
                raise ValueError(
                    "search_tree_depth must be set for tree search strategy."
                )
            self.concept_forrest, self.concept_dict = compute_concept_tree(
                sae_latent_activations,
                depth=self.search_tree_depth,
                cache_dir=str(self.cache_dir),
                verbose=self.verbose,
                grow_new_trees=False,
                balance_samples=False,
            )
        else:
            raise ValueError(
                f'Invalid search_strategy: {self.search_strategy}. Choose from "top_k" or "tree".'
            )

    def get_concept_names(self):
        """
        Get the list of concept names.

        Returns:
            list[str]: List of concept names.
        """
        return self.concepts

    def get_concept_dict(self):
        """
        Get the concept dictionary.

        Returns:
            dict: Concept name -> dict with its selected indices and per-index values.
        """
        return self.concept_dict

    def get_concept_tensor(self):
        """Get the tensor representation of the concept indices.

        Returns:
            torch.Tensor: Tensor of shape (num_concepts, search_top_k) containing the indices of the top-k concepts.
            list[str]: List of concept names.

        Raises:
            ValueError: If a concept has no entry in the concept dictionary.
        """
        search_top_k = []
        for concept in self.concepts:
            if concept not in self.concept_dict:
                raise ValueError(
                    f"Concept {concept} not found in SAE features for concepts. Please check that the detector is set up correctly."
                )
            search_top_k.append(self.concept_dict[concept]["indices"])
        # torch.tensor needs equal-length rows; this holds for top_k search
        # (every concept has k indices) but not necessarily for tree search.
        return torch.tensor(
            search_top_k
        ), self.concepts  # shape: (num_concepts, search_top_k), (num_concepts,)

    def get_concept_indices(self, concept: str, top_k: int | None = None):
        """
        Get the indices for a specific concept.

        Args:
            concept (str): The name of the concept.
            top_k (int, optional): If specified, return only the first top_k indices. Defaults to None (all).

        Returns:
            list: List of indices for the specified concept.

        Raises:
            ValueError: If the concept is not in the concept dictionary.
        """
        if concept not in self.concept_dict:
            raise ValueError(
                f"Concept '{concept}' not found in the concept dictionary."
            )
        indices = self.concept_dict[concept]["indices"]
        return indices[:top_k] if top_k is not None else indices

    def get_concept_weights(self, concept: str, top_k: int | None = None):
        """
        Get the weights for a specific concept.

        For top_k search these are the stored "weights" (positive-minus-negative
        mean activations). Tree search stores its per-feature numbers under
        "values" (split thresholds), which are returned as a fallback.

        Args:
            concept (str): The name of the concept.
            top_k (int, optional): If specified, return only the first top_k weights. Defaults to None (all).

        Returns:
            list: List of weights for the specified concept.

        Raises:
            ValueError: If the concept is not in the concept dictionary.
        """
        if concept not in self.concept_dict:
            raise ValueError(
                f"Concept '{concept}' not found in the concept dictionary."
            )
        entry = self.concept_dict[concept]
        weights = entry["weights"] if "weights" in entry else entry["values"]
        return weights[:top_k] if top_k is not None else weights

    def travers_concept_tree(
        self,
        activations: torch.Tensor,
        concept_names: list[str] | None = None,
        return_probabilities: bool = True,
    ) -> torch.Tensor:
        """
        Run each concept's decision tree on the given activations.

        Args:
            activations (torch.Tensor | np.ndarray): Samples to classify, one row per
                sample; presumably (num_samples, sae_latent_dimension).
            concept_names (list[str], optional): Concepts to evaluate. If None, all
                concepts of this instance are used.
            return_probabilities (bool): If True, return the probability of the
                positive class; otherwise the predicted label as a boolean.

        Returns:
            torch.Tensor: float probabilities or bool predictions, shape (num_samples, num_concepts).

        Raises:
            ValueError: If no forest is built, a concept has no tree, or the
                predicted probabilities do not sum to 1.
        """
        if not self.concept_forrest:
            raise ValueError(
                "Concept forest is empty. Please build the concept forest first using 'build_concepts()'."
            )
        concepts = self.concepts if concept_names is None else concept_names

        # sklearn trees consume numpy arrays.
        if isinstance(activations, torch.Tensor):
            activations_np = activations.cpu().numpy()
        else:
            activations_np = activations
        num_samples = activations_np.shape[0]

        dtype = torch.float if return_probabilities else torch.bool
        result = torch.zeros((num_samples, len(concepts)), dtype=dtype)

        for i, concept in enumerate(concepts):
            if concept not in self.concept_forrest:
                raise ValueError(
                    f"Concept '{concept}' not found in the concept forest."
                )
            concept_tree = self.concept_forrest[concept]

            if return_probabilities:
                probabilities = concept_tree.predict_proba(activations_np)
                if not np.allclose(probabilities.sum(axis=1), 1):
                    raise ValueError(
                        f"Probabilities for concept '{concept}' do not sum to 1 across classes. Check the decision tree model."
                    )
                # Locate class 1 via classes_ rather than assuming column 1: a
                # tree fitted on a single class has only one column. If class 1
                # was never seen, its probability stays 0.
                classes = list(concept_tree.classes_)
                if 1 in classes:
                    result[:, i] = torch.as_tensor(
                        probabilities[:, classes.index(1)], dtype=torch.float
                    )
            else:
                # 1 = concept present, 0 = absent
                predictions = concept_tree.predict(activations_np)
                result[:, i] = torch.tensor(predictions, dtype=torch.bool)

        return result


def compute_top_k_activations(
    activations: dict[str, dict[str, torch.Tensor]],
    k=10,
    search_top_k_order: str = "unique_first",
    verbose: bool = False,
) -> dict:
    """
    Select the k SAE latents most associated with each concept.

    Each latent is scored per concept as ``mean(positive) - mean(negative)``
    (means over dim 0). The k highest-scoring latents are kept and optionally
    reordered by how many concepts share them.

    Args:
        activations: ``{"positive": {concept: Tensor}, "negative": {concept: Tensor}}``.
            Tensors are averaged over dim 0, so presumably
            (num_samples, sae_latent_dimension).
        k (int): Number of latents kept per concept. ``torch.topk`` raises if k
            exceeds the latent dimension.
        search_top_k_order (str): "original_order" (by score), "unique_first"
            (latents picked by fewer concepts first; ties keep score order), or
            "unique_only" (as unique_first, but shared latents are replaced by
            index 0 / weight 0).
        verbose (bool): Print a colored overview of the selection.

    Returns:
        dict: ``{concept: {"indices": list[int], "weights": list[float]}}``;
        empty if there are no concepts.

    Raises:
        ValueError: If ``search_top_k_order`` is not supported.
    """
    if not activations or not activations.get("positive"):
        return {}
    if search_top_k_order not in ("unique_first", "unique_only", "original_order"):
        raise ValueError(f"Indexing strategy {search_top_k_order} not supported.")

    activations_pos = activations["positive"]
    activations_neg = activations["negative"]
    concept_names = list(activations_pos.keys())
    sae_latent_dimension = activations_pos[concept_names[0]].shape[-1]

    # Subtracting the negative mean removes latents that fire regardless of the
    # concept. Stacking once avoids re-allocating the matrix per concept.
    mean_activations = torch.stack(
        [
            activations_pos[c].mean(dim=0) - activations_neg[c].mean(dim=0)
            for c in concept_names
        ]
    )  # (num_concepts, sae_latent_dimension)

    top_values, top_indices = torch.topk(
        mean_activations, k, dim=1
    )  # (num_concepts, k), (num_concepts, k)

    # How many concepts selected each latent in their top-k.
    unique_indices, counts = torch.unique(top_indices.flatten(), return_counts=True)
    index_to_freq = dict(zip(unique_indices.tolist(), counts.tolist()))

    # Reorder each row so that latents shared by fewer concepts come first.
    top_indices_reordered = torch.zeros_like(top_indices)
    top_values_reordered = torch.zeros_like(top_values)
    for i, (row_indices, row_values) in enumerate(zip(top_indices, top_values)):
        if row_indices.numel() == 0:
            continue
        frequencies = torch.tensor(
            [index_to_freq[idx] for idx in row_indices.tolist()],
            device=row_indices.device,
        )
        # stable=True keeps the descending-score order among latents that are
        # shared by the same number of concepts.
        _, order = torch.sort(frequencies, stable=True)
        top_indices_reordered[i] = row_indices[order]
        top_values_reordered[i] = row_values[order]

    max_width = len(str(sae_latent_dimension))
    con_width = max(len(c) for c in concept_names) + 5

    def _colorize(idx: int) -> str:
        # Green: latent selected by exactly one concept. Orange: shared latent
        # (or the 0 placeholder written by "unique_only").
        color = COLORS["GREEN"] if idx != 0 and index_to_freq[idx] == 1 else COLORS["ORANGE"]
        return f"{color}{idx:>{max_width}}{RESET}"

    lines = [
        "-" * 20 + "AL Gready Concepts" + "-" * 20,
        f"Concept Overview (strategy: {search_top_k_order}):",
    ]
    results = {}
    for i, concept in enumerate(concept_names):
        org_indices = top_indices[i].cpu().tolist()
        if search_top_k_order == "original_order":
            indices = list(org_indices)
            weights = top_values[i].cpu().tolist()
        else:
            reordered_indices = top_indices_reordered[i].cpu().tolist()
            reordered_values = top_values_reordered[i].cpu().tolist()
            if search_top_k_order == "unique_first":
                indices, weights = reordered_indices, reordered_values
            else:  # "unique_only"
                # NOTE: 0 doubles as the "dropped" marker, so a unique latent 0
                # is indistinguishable from a dropped slot.
                keep = [index_to_freq[idx] == 1 for idx in reordered_indices]
                indices = [idx if kept else 0 for idx, kept in zip(reordered_indices, keep)]
                weights = [val if kept else 0 for val, kept in zip(reordered_values, keep)]

        colored_new = [_colorize(idx) for idx in indices]
        colored_org = [_colorize(idx) for idx in org_indices]
        row = ", ".join(colored_org)
        if colored_new != colored_org:
            # Show the reordered selection next to the original score order.
            row = ", ".join(colored_new) + "  <-  " + row
        lines.append(f"{concept:<{con_width}} " + row)
        results[concept] = {"indices": indices, "weights": weights}

    lines.append(
        f"Color legend: {COLORS['GREEN']}■{RESET} unique, {COLORS['ORANGE']}■{RESET} duplicate"
    )
    if verbose:
        print("\n".join(lines).strip())
    return results


def compute_concept_tree(
    activation_dict: dict[str, dict[str, torch.Tensor]],
    cache_dir: str,
    depth=10,
    grow_new_trees=False,
    balance_samples=True,
    verbose: bool = False,
) -> tuple[dict[str, DecisionTreeClassifier], dict]:
    """
    Return one decision tree per concept, loading from cache when possible.

    Trees are cached under ``<cache_dir>/activation_trees/<depth>``.

    Args:
        activation_dict: ``{"positive": {concept: Tensor}, "negative": {concept: Tensor}}``.
        cache_dir (str): Base cache directory.
        depth (int): Maximum tree depth; also part of the cache path.
        grow_new_trees (bool): If True, skip the cache and always refit.
        balance_samples (bool): Truncate positives/negatives to equal counts
            before fitting. NOTE: not part of the cache key, so a cached forest
            may have been fitted with a different setting.
        verbose (bool): Print progress, including cache-hit messages.

    Returns:
        tuple: (concept -> fitted tree, concept -> extracted tree features).
    """
    concept_names = list(activation_dict["positive"].keys())
    cache_dir = os.path.join(cache_dir, f"activation_trees/{depth}")

    if not grow_new_trees:
        # Forward verbose so cache hits/misses are reported too.
        cached_result = _try_load_cached_trees(
            cache_dir, concept_names, activation_dict, verbose=verbose
        )
        if cached_result is not None:
            return cached_result

    return _build_concept_trees(
        activation_dict,
        concept_names,
        cache_dir,
        depth,
        balance_samples=balance_samples,
        verbose=verbose,
    )


def _try_load_cached_trees(
    cache_dir: str,
    concept_names: list[str],
    activation_dict: dict,
    verbose: bool = False,
) -> tuple[dict, dict] | None:
    """Try to load cached decision trees."""
    # Use the same hash function as in _save_results
    import hashlib

    concept_hash = hashlib.md5("_".join(concept_names).encode()).hexdigest()
    cache_file = os.path.join(cache_dir, f"activations_{concept_hash}.pt")

    # For backward compatibility, also check the old naming pattern
    old_cache_file = os.path.join(
        cache_dir, f"activations_{''.join(concept_names)}.pt"
    ).replace(",", "_")

    # Check both the new hash-based filename and the old concatenated filename
    if os.path.exists(cache_dir) and os.path.exists(cache_file):
        if verbose:
            print(
                f"Loading cached concept trees from {cache_dir} (depth={cache_dir.split('/')[-1]})"
            )
        results = torch.load(cache_file)
        concept_forrest = {}
    elif os.path.exists(cache_dir) and os.path.exists(old_cache_file):
        if verbose:
            print(
                f"Loading cached concept trees from {cache_dir} using legacy filename (depth={cache_dir.split('/')[-1]})"
            )
        results = torch.load(old_cache_file)
        concept_forrest = {}
    else:
        return None

    for concept in concept_names:
        model_file = os.path.join(cache_dir, f"{concept}.joblib")
        if os.path.exists(model_file) and concept in activation_dict["positive"]:
            concept_forrest[concept] = joblib.load(model_file)
        else:
            break

    if len(concept_forrest) == len(concept_names):
        txt = f"Loaded {len(concept_forrest)} trees from cache for concepts: {', '.join(concept_names)}"
        if verbose:
            print(txt)
        return concept_forrest, results
    else:
        if verbose:
            print(
                f"Found {len(concept_forrest)}/{len(concept_names)} trees in cache, recomputing trees."
            )
        return None


def _build_concept_trees(
    activation_dict: dict,
    concept_names: list[str],
    cache_dir: str,
    depth: int,
    balance_samples: bool = True,
    verbose: bool = False,
) -> tuple[dict, dict]:
    """Fit one decision tree per concept, persist all artifacts, and return them.

    Args:
        activation_dict: ``{"positive": {concept: Tensor}, "negative": {concept: Tensor}}``.
            ``.shape[1]`` is read as the latent dimension, so tensors are presumably
            (num_samples, sae_latent_dimension).
        concept_names (list[str]): Concepts to fit, in output order.
        cache_dir (str): Directory for plots, ``.joblib`` models and results.
        depth (int): Maximum tree depth.
        balance_samples (bool): Truncate positives/negatives to equal counts.
        verbose (bool): Print the colored per-concept feature overview.

    Returns:
        tuple: (concept -> fitted tree,
        concept -> {"indices", "values", "impurity"} from the tree's positive splits).
        Both are empty when ``concept_names`` is empty.
    """
    # Nothing to fit; also avoids indexing concept_names[0] below.
    if not concept_names:
        return {}, {}

    concept_forrest = {}
    results = {}
    metadata_lines = _initialize_metadata(depth, concept_names)
    txt_lines = ["-" * 20 + "AL Concept Trees" + "-" * 20]

    # Column widths for the console overview.
    sae_latent_dimension = activation_dict["positive"][concept_names[0]].shape[1]
    max_width = len(str(sae_latent_dimension))
    con_width = max(len(c) for c in concept_names) + 5

    for concept in tqdm(concept_names, desc=f"Building concept trees (depth={depth})"):
        tree_data = _prepare_tree_data(
            activation_dict["positive"][concept],
            activation_dict["negative"][concept],
            concept,
            balance_samples=balance_samples,
        )

        clf = _train_decision_tree(tree_data, depth)
        concept_forrest[concept] = clf

        result_dict = extract_tree_features(clf)
        results[concept] = {
            "indices": result_dict["feature_indices"],
            "values": result_dict["thresholds"],
            "impurity": result_dict["impurities"],
        }

        txt_lines.append(
            _format_concept_output(concept, result_dict, con_width, max_width)
        )
        metadata_lines.extend(
            _generate_concept_metadata(concept, tree_data, clf, result_dict)
        )

        # Plot and model are written per concept as soon as the tree is fitted.
        _save_tree_artifacts(clf, concept, cache_dir, depth)

    if verbose:
        print("\n".join(txt_lines).strip())

    _save_results(cache_dir, concept_names, results, metadata_lines)

    return concept_forrest, results


def _initialize_metadata(depth: int, concept_names: list[str]) -> list[str]:
    """Build the header lines of the decision-tree metadata report.

    The header holds a timestamp, the tree depth, and the concept count and
    names, and ends with an 80-character ``=`` rule.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    header = ["Decision Tree Metadata", f"Generated on: {timestamp}"]
    header += [
        f"Tree depth: {depth}",
        f"Number of concepts: {len(concept_names)}",
        f"Concepts: {', '.join(concept_names)}",
    ]
    header.append("=" * 80)
    return header


def _prepare_tree_data(
    positive_activations: torch.Tensor,
    negative_activations: torch.Tensor,
    concept: str,
    balance_samples: bool = True,
) -> dict:
    """Assemble the training set for one concept's decision tree.

    Rows whose activations sum to zero or less are dropped. When
    ``balance_samples`` is set and the class counts differ, both classes are
    truncated to the smaller count. A warning is printed when balancing
    happens, and whenever balancing is disabled.

    Returns:
        dict: ``X`` (numpy, positives first), ``y`` (1 for positive, 0 for
        negative), the kept ``positive_samples``/``negative_samples`` tensors,
        and the post-filter ``original_pos_count``/``original_neg_count``.
    """
    pos = positive_activations[positive_activations.sum(dim=-1) > 0, :]
    neg = negative_activations[negative_activations.sum(dim=-1) > 0, :]
    n_pos, n_neg = len(pos), len(neg)

    if not balance_samples:
        print(
            f"Warning: Concept '{concept}' - using all samples without balancing (pos: {n_pos}, neg: {n_neg})"
        )
    elif n_pos != n_neg:
        n_keep = min(n_pos, n_neg)
        print(
            f"Warning: Concept '{concept}' - balancing samples to {n_keep} each (was {n_pos} pos, {n_neg} neg)"
        )
        pos, neg = pos[:n_keep], neg[:n_keep]

    features = torch.cat([pos, neg], dim=0).cpu().numpy()
    labels = [1] * len(pos) + [0] * len(neg)

    return {
        "X": features,
        "y": labels,
        "positive_samples": pos,
        "negative_samples": neg,
        "original_pos_count": n_pos,
        "original_neg_count": n_neg,
    }


def _train_decision_tree(
    tree_data: dict, depth: int, random_state: int | None = 0
) -> DecisionTreeClassifier:
    """Fit a class-balanced Gini decision tree of at most ``depth`` levels.

    Args:
        tree_data (dict): Training data with ``"X"`` (features) and ``"y"`` (labels).
        depth (int): ``max_depth`` of the tree.
        random_state (int | None): Seed for scikit-learn's per-split feature
            permutation. Seeding makes the selected split features reproducible,
            which matters because trees are cached and their feature indices are
            reported. Pass None for unseeded fitting.

    Returns:
        DecisionTreeClassifier: The fitted tree.
    """
    clf = DecisionTreeClassifier(
        criterion="gini",
        max_depth=depth,
        class_weight="balanced",
        random_state=random_state,
    )
    return clf.fit(tree_data["X"], tree_data["y"])


def _format_concept_output(
    concept: str, result_dict: dict, con_width: int, max_width: int
) -> str:
    """Render a concept's tree features as two colored console lines.

    The "(pos)" line lists the positive-split features with their thresholds in
    green. The "(neg)" line lists the remaining split features in red. An empty
    list is shown as "None".
    """
    label_width = con_width + 6

    def _line(suffix: str, entries: list[str]) -> str:
        label = f"{concept + suffix:<{label_width}} "
        return label + (", ".join(entries) if entries else "None")

    positive_entries = [
        f"{COLORS['GREEN']}{idx:>{max_width}} ({threshold:.2f}){RESET}"
        for idx, threshold in zip(
            result_dict["feature_indices"], result_dict["thresholds"]
        )
    ]
    negative_entries = [
        f"{COLORS['RED']}{idx:>{max_width}}{RESET}"
        for idx in result_dict["negative_features"]
    ]
    return "\n".join(
        [_line(" (pos):", positive_entries), _line(" (neg):", negative_entries)]
    )


def _generate_concept_metadata(
    concept: str, tree_data: dict, clf: DecisionTreeClassifier, result_dict: dict
) -> list[str]:
    """Summarize one concept's tree: sample counts, structure and training metrics.

    Metrics are computed on the training data itself (``tree_data["X"]``/``["y"]``),
    so they describe fit rather than generalization. A one-line summary is
    also printed.

    Args:
        concept (str): Concept name.
        tree_data (dict): Output of ``_prepare_tree_data``.
        clf (DecisionTreeClassifier): The fitted tree.
        result_dict (dict): Output of ``extract_tree_features``.

    Returns:
        list[str]: Report lines for the metadata file.
    """
    # Convert once. The counts below are vectorized, unlike builtin sum(),
    # which iterates numpy arrays element by element in Python.
    predictions = np.asarray(clf.predict(tree_data["X"]))
    labels = np.asarray(tree_data["y"])

    tp = int(np.count_nonzero((predictions == 1) & (labels == 1)))
    tn = int(np.count_nonzero((predictions == 0) & (labels == 0)))
    fp = int(np.count_nonzero((predictions == 1) & (labels == 0)))
    fn = int(np.count_nonzero((predictions == 0) & (labels == 1)))

    # The max() guards keep each ratio defined when its denominator is zero.
    precision = tp / max(1, tp + fp)
    recall = tp / max(1, tp + fn)
    f1 = 2 * precision * recall / max(1e-8, precision + recall)
    acc = (tp + tn) / max(1, len(tree_data["X"]))

    metadata = [
        f"\nConcept: {concept}",
        "-" * 40,
        f"Original samples - Positive: {tree_data['original_pos_count']}, Negative: {tree_data['original_neg_count']}",
        f"Used samples - Positive: {len(tree_data['positive_samples'])}, Negative: {len(tree_data['negative_samples'])}",
        f"Balance ratio: {len(tree_data['positive_samples']) / max(1, len(tree_data['X'])):.3f}",
        f"Tree depth: {clf.get_depth()}",
        f"Number of nodes: {clf.tree_.node_count}",
        f"Number of leaves: {clf.get_n_leaves()}",
        f"Training accuracy: {acc:.4f}",
        f"Root impurity: {clf.tree_.impurity[0]:.4f}",
        "Confusion Matrix:",
        f"  TP: {tp}, TN: {tn}, FP: {fp}, FN: {fn}",
        f"  Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}",
        f"Positive features: {len(result_dict['feature_indices'])}",
        f"Negative features: {len(result_dict['negative_features'])}",
    ]
    print(
        f"Concept '{concept}' - TP: {tp}, TN: {tn}, FP: {fp}, FN: {fn}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}, Accuracy: {acc:.4f}"
    )

    # Up to 5 most important features, listed only for trees deeper than 1.
    if clf.get_depth() > 1:
        importances = clf.feature_importances_
        top_indices = np.argsort(importances)[-5:][::-1]
        top_features = [(i, importances[i]) for i in top_indices if importances[i] > 0]
        if top_features:
            metadata.append("Top feature importances:")
            metadata.extend(f"  Feature {idx}: {imp:.4f}" for idx, imp in top_features)

    return metadata


def _save_tree_artifacts(
    clf: DecisionTreeClassifier, concept: str, cache_dir: str, depth: int
):
    """Write a PNG rendering of ``clf`` and a joblib dump of it into ``cache_dir``.

    The plot goes to ``activation_tree_<concept>_k<depth>_1.png`` (via
    ``plot_tree``, without showing it) and the model to ``<concept>.joblib``.
    """
    os.makedirs(cache_dir, exist_ok=True)

    plot_options = dict(
        class_names=["Negative", concept],
        filled=True,
        fontsize=10,
        save_dir=cache_dir,
        filename_prefix=f"activation_tree_{concept}_k{depth}",
        show=False,
        figsize=(12, 8),
    )
    plot_tree(clf, **plot_options)

    joblib.dump(clf, os.path.join(cache_dir, f"{concept}.joblib"))


def _save_results(
    cache_dir: str, concept_names: list[str], results: dict, metadata_lines: list[str]
):
    """Persist per-concept tree results and the metadata report.

    Writes ``activations_<md5>.pt`` (``torch.save`` of ``results``) and
    ``tree_metadata_<md5>.txt`` into ``cache_dir``. Here ``<md5>`` is the MD5
    of the "_"-joined concept names, the same key ``_try_load_cached_trees``
    looks up. The caller's ``metadata_lines`` list is not modified.
    """
    import hashlib

    os.makedirs(cache_dir, exist_ok=True)

    # Hashing keeps filenames short regardless of how many concepts there are.
    concept_hash = hashlib.md5("_".join(concept_names).encode()).hexdigest()
    torch.save(results, os.path.join(cache_dir, f"activations_{concept_hash}.pt"))

    # Record the hash -> concepts mapping, since the filename alone hides it.
    report = [
        *metadata_lines,
        f"\nFilename hash: {concept_hash}",
        f"Concepts included in this hash: {', '.join(concept_names)}",
    ]

    metadata_file = os.path.join(cache_dir, f"tree_metadata_{concept_hash}.txt")
    # Use explicit UTF-8: concept names may be non-ASCII, and the platform's
    # default encoding is not guaranteed to handle them.
    with open(metadata_file, "w", encoding="utf-8") as f:
        f.write("\n".join(report))
    print(f"Metadata saved to: {metadata_file}")


def plot_tree(
    tree_model,
    class_names: List[str],
    filled=True,
    fontsize=10,
    save_dir=None,
    filename_prefix="tree",
    show=True,
    figsize=(12, 8),
):
    """
    Plot one or more decision trees using matplotlib.

    Parameters:
        tree_model: A single DecisionTreeClassifier or a list of them
        class_names (List[str]): List of class names
        filled (bool): Whether to fill the nodes with colors based on the class
        fontsize (int): Font size for the text in the plot
        save_dir (str, optional): Directory to save ``<filename_prefix>_<i>.png`` files to (i starts at 1)
        filename_prefix (str): Prefix for the saved filenames
        show (bool): Whether to display the plots interactively. If False, the
            figures created here are closed after saving.
        figsize (tuple): Figure size for each tree

    Returns:
        list: List of generated figure objects
    """
    models = tree_model if isinstance(tree_model, list) else [tree_model]
    figures = []

    for i, model in enumerate(models, start=1):
        # Draw on an explicit axes instead of relying on pyplot's current axes.
        fig, ax = plt.subplots(figsize=figsize)
        sk_plot_tree(
            model, class_names=class_names, filled=filled, fontsize=fontsize, ax=ax
        )
        if len(models) > 1:
            ax.set_title(f"Decision Tree {i}")

        fig.tight_layout()
        figures.append(fig)

        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
            save_path = os.path.join(save_dir, f"{filename_prefix}_{i}.png")
            fig.savefig(save_path, bbox_inches="tight", dpi=300)

    if show:
        plt.show()
    else:
        # Close only the figures created here; plt.close("all") would also
        # destroy unrelated figures the caller has open.
        for fig in figures:
            plt.close(fig)
    return figures


def get_root_node(
    model_file: str | None = None, tree_model: DecisionTreeClassifier | None = None
) -> Dict[str, int | float]:
    """
    Extract the root split of a fitted DecisionTreeClassifier.

    Parameters:
        model_file (str, optional): Path to a joblib-dumped tree; used only when
            ``tree_model`` is None. NOTE: joblib.load unpickles, so only load
            trusted files.
        tree_model (DecisionTreeClassifier, optional): An already loaded tree.

    Returns:
        dict: ``feature_index`` (int), ``threshold`` (float) and ``impurity``
        (float) of node 0. For a tree without any split, the root is a leaf and
        scikit-learn stores a placeholder feature (-2) there.

    Raises:
        ValueError: If neither argument is given, or the model is not fitted.
    """
    if tree_model is None:
        if model_file is None:
            raise ValueError("Either model_file or tree_model must be provided.")
        tree_model = joblib.load(model_file)

    if not hasattr(tree_model, "tree_"):
        raise ValueError("The tree model is not fitted yet.")

    tree = tree_model.tree_
    root_index = 0  # scikit-learn stores the root at node 0

    return {
        "feature_index": tree.feature[root_index].item(),
        "threshold": tree.threshold[root_index].item(),
        "impurity": tree.impurity[root_index].item(),
    }


def extract_tree_features(tree_model: DecisionTreeClassifier) -> Dict[str, List]:
    """
    Split a fitted tree's internal nodes into positive and negative features.

    For each internal node, the class distribution stored for its right child
    is compared. In scikit-learn the right child receives samples whose
    feature exceeds the threshold. If class 1 outweighs class 0 there, the
    node's feature, threshold and impurity are recorded as a positive feature;
    otherwise only the feature index goes to ``negative_features``.
    (``tree_.value`` holds class-weighted counts or fractions depending on the
    scikit-learn version; the comparison works for either.)

    Args:
        tree_model (DecisionTreeClassifier): A fitted binary (0/1) tree.

    Returns:
        dict: ``feature_indices``, ``thresholds``, ``impurities`` (aligned lists
        for positive splits) and ``negative_features``.
    """
    structure = tree_model.tree_
    out = {
        "feature_indices": [],
        "thresholds": [],
        "impurities": [],
        "negative_features": [],
    }

    for node in range(structure.node_count):
        if structure.feature[node] == -2:  # leaves carry feature == -2
            continue
        right = structure.children_right[node]
        if right == -1:  # no right child
            continue

        class_values = structure.value[right][0]
        feature = structure.feature[node].item()
        if class_values[1] > class_values[0]:
            out["feature_indices"].append(feature)
            out["thresholds"].append(structure.threshold[node].item())
            out["impurities"].append(structure.impurity[node].item())
        else:
            out["negative_features"].append(feature)

    return out
