import torch
from typing import List, Optional, Tuple
from ar.utils import get_weight_vector
import torch.nn.functional as F
from ar.model.concept import Concepts


class ALConceptDetector:
    """
    Detect concept activations in SAE latent activations and keep a running record of them.

    This component handles:
    1. Tracking which concepts are active at each (batch, token) position
    2. Detecting concepts either by matching top-k latent indices (``detect_concepts``)
       or by traversing a concept tree (``tree_detect_concepts``)
    3. Accumulating masks and confidences across successive detection calls by
       concatenating along the sequence axis, until ``reset()`` is called

    Attributes:
        verbose (bool): Flag to control verbosity of logging
        concepts (Optional[Concepts]): Concepts object supplied via ``set_up``; ``None`` until then
        concept_names (List[str]): Concept names from ``concepts.get_concept_names()``
        activate_concept_mask (torch.Tensor): Int mask of activated concepts
            (shape: (batch, sequence_length, num_concepts)); empty until the first detection
        activate_concept_confidence (torch.Tensor): Per-index activation strengths
            (shape: (batch, sequence_length, num_concepts, detection_top_k_concepts));
            empty until the first detection
        detect (Callable): Detection entry point. It is ``tree_detect_concepts`` when the
            concepts' ``search_strategy`` is "tree", otherwise ``detect_concepts``.
    """

    def __init__(
        self,
        verbose: bool = True,
    ):
        """
        Initialize the detector with empty state.

        ``set_up()`` must be called with a Concepts object before detecting.

        Args:
            verbose (bool): Flag to control verbosity of logging
        """

        # Input logging
        self.verbose = verbose

        # Concepts
        self.concepts = None  # type: Optional[Concepts]
        self.concept_names = []  # List of concept names

        # Until set_up() picks a strategy, route to detect_concepts. It raises a clear
        # "not set up" ValueError instead of an AttributeError on a missing `detect`.
        self.detect = self.detect_concepts

        # State variables
        # shape is (batch, sequence_length, num_concepts), a mask that can be mapped on the input
        self.activate_concept_mask = torch.empty(0)
        # shape is (batch, sequence_length, num_concepts, detection_top_k_concepts)
        self.activate_concept_confidence = torch.empty(0)

    def set_up(self, concepts: "Concepts", verbose: Optional[bool] = False):
        """
        Attach a Concepts object and choose the detection strategy.

        This method is called after the component is initialized.

        Args:
            concepts (Concepts): Concepts object providing concept names, indices and the search strategy
            verbose (bool): Flag to control verbosity of logging
        """
        # NOTE(review): with the default of False, the fallback to the current setting
        # only applies when verbose=None is passed explicitly.
        self.verbose = verbose if verbose is not None else self.verbose
        self.concepts = concepts
        self.concept_names = self.concepts.get_concept_names()  # List of concept names

        self.detect = (
            self.tree_detect_concepts
            if self.concepts.search_strategy == "tree"
            else self.detect_concepts
        )

    def reset(self):
        """
        Reset the accumulated mask and confidence state to empty tensors.
        """
        self.activate_concept_mask = torch.empty(0)
        self.activate_concept_confidence = torch.empty(0)

    def module_ready(self) -> bool:
        """
        Check whether the component has been set up.

        Returns:
            bool: True once ``set_up()`` has supplied a Concepts object, False otherwise
        """
        return self.concepts is not None

    @staticmethod
    def _compute_device() -> torch.device:
        """Return the device used by ``detect_concepts``: CUDA when available, otherwise CPU."""
        return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    def _accumulate(self, concept_mask: torch.Tensor, confidence: torch.Tensor) -> None:
        """
        Append one detection call's mask and confidence to the running state.

        New results are concatenated along the sequence axis (dim=1).

        Raises:
            ValueError: If the batch size differs from the previously stored state.
        """
        if self.activate_concept_mask.numel() == 0:
            # First call (or first call after reset): adopt the tensors directly.
            self.activate_concept_mask = concept_mask
            self.activate_concept_confidence = confidence
            return

        if self.activate_concept_mask.shape[0] != concept_mask.shape[0]:
            raise ValueError(
                f"Batch size has changed from {self.activate_concept_mask.shape[0]} to {concept_mask.shape[0]}. Please use same batch size for all calls to detect_concepts. or reset the module."
            )

        self.activate_concept_mask = torch.cat(
            (self.activate_concept_mask, concept_mask), dim=1
        )
        self.activate_concept_confidence = torch.cat(
            (self.activate_concept_confidence, confidence), dim=1
        )

    def tree_detect_concepts(
        self,
        latent_activations: torch.Tensor,
        attention_mask: torch.Tensor,
        detection_top_k_output: int,
        detection_top_k_concepts: int,
        detection_threshold: float,
        detection_allow_multi: bool,
    ) -> Tuple[torch.Tensor, List[str]]:
        """
        Detect concepts by traversing the concept tree over sparsified latent activations.

        The results are appended to the running state.

        Args:
            latent_activations (torch.Tensor): Latent activations from the SAE model (batch, sequence_length, latent_dim)
            attention_mask (torch.Tensor): Attention mask for the input sequence (batch, sequence_length)
            detection_top_k_output (int): Number of top latents kept per token (clamped to latent_dim)
            detection_top_k_concepts (int): Size of the last axis of the stored confidence tensor
            detection_threshold (float): Probabilities below this value are zeroed
            detection_allow_multi (bool): Accepted for interface parity with ``detect_concepts``; unused here
        Returns:
            torch.Tensor: Accumulated mask of activated concepts (shape: (batch, total_sequence_length, num_concepts))
            list[str]: List of concept names
        Raises:
            ValueError: If ``set_up()`` has not been called, or the batch size changed between calls.
        """
        if self.concepts is None:
            raise ValueError(
                "Concepts have not been set up. Please call 'set_up()' with a Concepts object before using this method."
            )
        b_size, seq_len, latent_dim = latent_activations.shape
        latent_activations = latent_activations.cpu()

        # Keep only the top-k latents per token; every other latent is zeroed.
        top_k_output = min(detection_top_k_output, latent_dim)
        top_activations, top_output_indices = torch.topk(
            latent_activations, top_k_output, dim=-1
        )
        # scatter requires src and destination dtypes to match. Cast to float so
        # half/bf16 activations work.
        sparse_latents = torch.zeros_like(latent_activations, dtype=torch.float).scatter(
            -1, top_output_indices, top_activations.float()
        )  # (batch, sequence_length, latent_dim)
        flat_latents = sparse_latents.view(b_size * seq_len, latent_dim)

        # Returns either one row per sequence (batch, num_concepts) or one row per
        # token (batch * sequence_length, num_concepts).
        tree_preds_prob = self.concepts.travers_concept_tree(
            flat_latents, self.concept_names
        )

        if tree_preds_prob.shape[0] == b_size:
            # Per-sequence prediction: broadcast it to every token position.
            activate_concept_prob = (
                tree_preds_prob.unsqueeze(1).expand(-1, seq_len, -1).contiguous()
            )
        else:
            activate_concept_prob = tree_preds_prob.view(b_size, seq_len, -1)

        # Zero out padded positions, then apply the detection threshold.
        activate_concept_prob = activate_concept_prob * attention_mask.unsqueeze(-1).to(
            tree_preds_prob.device
        )  # (batch, sequence_length, num_concepts)
        activate_concept_prob[activate_concept_prob < detection_threshold] = 0.0

        # Spread each probability evenly over detection_top_k_concepts slots.
        # This keeps the confidence layout identical to detect_concepts.
        confidence = (
            activate_concept_prob.unsqueeze(-1)
            .float()
            .expand(-1, -1, -1, detection_top_k_concepts)
            / detection_top_k_concepts
        )  # (batch, sequence_length, num_concepts, detection_top_k_concepts)

        activate_concept_mask = (activate_concept_prob > 0).int()

        self._accumulate(activate_concept_mask, confidence)
        return self.activate_concept_mask, self.concept_names

    def detect_concepts(
        self,
        latent_activations: torch.Tensor,
        attention_mask: torch.Tensor,
        detection_top_k_output: int,
        detection_top_k_concepts: int,
        detection_threshold: float,
        detection_allow_multi: bool,
    ) -> Tuple[torch.Tensor, List[str]]:
        """
        Detect concepts by matching each token's top-k latent indices against each concept's indices.

        The results are appended to the running state.

        Args:
            latent_activations (torch.Tensor): Latent activations from the SAE model (batch, sequence_length, latent_dim)
            attention_mask (torch.Tensor): Attention mask for the input sequence; not used by this strategy
            detection_top_k_output (int): Number of top latents considered per token (clamped to latent_dim)
            detection_top_k_concepts (int): Number of leading indices per concept to match against
            detection_threshold (float): A concept is active when its weighted activation sum exceeds this value
            detection_allow_multi (bool): If False, keep only the strongest active concept per token
        Returns:
            torch.Tensor: Accumulated int mask of activated concepts (shape: (batch, total_sequence_length, num_concepts))
            list[str]: List of concept names
        Raises:
            ValueError: If ``set_up()`` has not been called, ``detection_top_k_concepts`` exceeds the
                available concept indices, or the batch size changed between calls.
        """
        if self.concepts is None:
            raise ValueError(
                "Concepts have not been set up. Please call 'set_up()' with a Concepts object before using this method."
            )
        # (num_concepts, num_indices): latent indices belonging to each concept
        concept_tensor, _ = self.concepts.get_concept_tensor()
        if detection_top_k_concepts > concept_tensor.shape[1]:
            raise ValueError(
                f"Top k concepts ({detection_top_k_concepts}) for is greater than the number of indices available ({concept_tensor.shape[1]}). Please check the concept indices."
            )

        device = self._compute_device()
        latent_activations = latent_activations.to(device)
        top_k_output = min(detection_top_k_output, latent_activations.shape[-1])
        top_activations, top_output_indices = torch.topk(
            latent_activations, top_k_output, dim=-1
        )  # each (batch, sequence_length, top_k_output)

        # (1, 1, 1, num_concepts, detection_top_k_concepts)
        concept_indices = concept_tensor[:, :detection_top_k_concepts].to(device)[
            None, None, None
        ]

        # Broadcast compare: (batch, seq, top_k_output, num_concepts, detection_top_k_concepts)
        matches = top_output_indices[..., None, None] == concept_indices
        concept_activations = top_activations[..., None, None] * matches.float()

        # Sum over the top-k output axis:
        # (batch, sequence_length, num_concepts, detection_top_k_concepts)
        activations = concept_activations.sum(dim=2)

        # Weight each concept index by its position.
        weights = get_weight_vector(
            detection_top_k_concepts,
            steering_weighting_function="log_decay",
            mean=1,
            std=None,
        ).to(activations.device)
        activations = activations * weights

        # Threshold the per-concept total and zero the activations of inactive concepts.
        activate_concept_mask = (
            activations.sum(dim=-1) > detection_threshold
        ).int()  # (batch, sequence_length, num_concepts)
        activations = activations * activate_concept_mask.unsqueeze(-1)

        num_concepts = activate_concept_mask.shape[2]
        if not detection_allow_multi and num_concepts > 1 and activations.any():
            # Keep only the strongest concept per token.
            strongest = torch.argmax(activations.sum(dim=-1), dim=-1)
            one_hot = F.one_hot(strongest, num_classes=num_concepts)
            # AND with the threshold mask so tokens whose argmax landed on a zero
            # activation stay inactive. Keep the mask int like the other branch.
            activate_concept_mask = (one_hot * activate_concept_mask).int()
            activations = activations * activate_concept_mask.unsqueeze(-1)

        self._accumulate(activate_concept_mask, activations)
        return self.activate_concept_mask, self.concept_names

    def get_concept_mask(self) -> torch.Tensor:
        """Return the accumulated concept mask (empty before the first detection)."""
        return self.activate_concept_mask

    def get_concept_names(self) -> List[str]:
        """Return the concept names captured by ``set_up()``."""
        return self.concept_names

    def get_weighted_concepts(self) -> torch.Tensor:
        """
        Collapse the stored per-index confidences into one score per concept.

        The last axis of ``activate_concept_confidence`` (one entry per concept index) is
        reduced by a dot product with ``get_weight_vector(top_k, "log_decay")``.
        Presumably this down-weights later concept indices; confirm against
        ``get_weight_vector``.

        NOTE(review): ``detect_concepts`` already multiplies its stored confidences by the
        same "log_decay" weights, so on that path the weighting is applied twice. Confirm
        this is intended.

        Returns:
            torch.Tensor: Weighted concept scores (shape: (batch, sequence_length, num_concepts))
        """
        top_k = self.activate_concept_confidence.shape[-1]
        weight_vector = get_weight_vector(
            top_k, steering_weighting_function="log_decay", mean=1, std=None
        ).to(self.activate_concept_confidence.device)  # Shape: (top_k)
        return self.activate_concept_confidence @ weight_vector

    def extract_concepts(
        self,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[str], torch.Tensor]:
        """
        List every active concept with its batch id, sequence position and weighted confidence.

        Returns:
            batch_ids (torch.Tensor): Batch ids (shape: total number of activated concepts)
            seq_ids (torch.Tensor): Sequence ids (shape: total number of activated concepts)
            concept_ids (torch.Tensor): Concept ids (shape: total number of activated concepts)
            concept_names (List of str): Name of each activated concept
            confidences (torch.Tensor): Weighted confidence for each activation (shape: total number of activated concepts)
        Raises:
            ValueError: If ``set_up()`` has not been called.
        """
        if self.concepts is None:
            raise ValueError(
                "Concepts have not been set up. Please call 'set_up()' with a Concepts object before using this method."
            )
        concept_mask = self.activate_concept_mask  # (batch, sequence_length, num_concepts)
        if concept_mask.numel() == 0:
            # Nothing detected yet (or state was reset). The empty placeholder is 1-D,
            # so nonzero() would not yield three index tensors; return empty results.
            empty_ids = torch.empty(0, dtype=torch.long, device=concept_mask.device)
            return (
                empty_ids,
                empty_ids.clone(),
                empty_ids.clone(),
                [],
                torch.empty(0, device=concept_mask.device),
            )

        weighted = self.get_weighted_concepts()  # (batch, sequence_length, num_concepts)
        batch_ids, seq_ids, concept_ids = torch.nonzero(concept_mask, as_tuple=True)
        confidences = weighted[batch_ids, seq_ids, concept_ids]
        concept_names = [self.concept_names[i] for i in concept_ids.tolist()]
        return batch_ids, seq_ids, concept_ids, concept_names, confidences

    def get_activations(self) -> torch.Tensor:
        """
        Get the accumulated per-index activation matrix for the concepts.

        Returns:
            torch.Tensor: Activation matrix (shape: (batch, sequence_length, num_concepts, detection_top_k_concepts))
        """
        return self.activate_concept_confidence

    def get_concepts(self) -> List[str]:
        """
        Get the list of concept names.

        Returns:
            List[str]: Concept names, or an empty list if no Concepts object is set
        """
        return self.concept_names if self.concepts else []
